|
|
|
|
|
|
|
|
""" |
|
|
This module contains classes and utilities for building variable trackers in Dynamo. |
|
|
Variable trackers are used to convert Python values into symbolic representations |
|
|
that can be traced and transformed during graph capture. |
|
|
|
|
|
The key classes are: |
|
|
|
|
|
- VariableBuilder: Handles source-tracked objects that need guards and proper |
|
|
reconstruction in the output graph. Used for inputs, module attributes, etc. |
|
|
|
|
|
- SourcelessBuilder: Handles ephemeral objects created during tracing that don't |
|
|
need source tracking or guards. Used for temporary lists, intermediate values, etc. |
|
|
|
|
|
Variable trackers enable Dynamo to track the flow of values through the program, |
|
|
maintain guards for dynamic properties, and reconstruct values in the output graph. |
|
|
The builders in this module handle converting Python values into appropriate |
|
|
VariableTracker instances based on their type and usage context. |
|
|
""" |
|
|
|
|
|
import abc |
|
|
import collections |
|
|
import contextlib |
|
|
import copy |
|
|
import dataclasses |
|
|
import enum |
|
|
import functools |
|
|
import inspect |
|
|
import itertools |
|
|
import logging |
|
|
import math |
|
|
import operator |
|
|
import random |
|
|
import re |
|
|
import sys |
|
|
import traceback |
|
|
import types |
|
|
import weakref |
|
|
from collections.abc import MutableMapping |
|
|
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union |
|
|
|
|
|
import sympy |
|
|
|
|
|
import torch |
|
|
from torch import SymInt |
|
|
from torch._dispatch.python import enable_python_dispatcher |
|
|
from torch._dynamo.utils import ( |
|
|
get_metrics_context, |
|
|
is_int_specialization_case, |
|
|
is_torch_sym, |
|
|
set_feature_use, |
|
|
) |
|
|
from torch._guards import TracingContext |
|
|
from torch._higher_order_ops.flat_apply import flat_apply |
|
|
from torch._higher_order_ops.torchbind import call_torchbind |
|
|
from torch._ops import HigherOrderOperator |
|
|
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode |
|
|
from torch._subclasses.meta_utils import is_sparse_any, safe_grad |
|
|
from torch._utils_internal import justknobs_check |
|
|
from torch.fx.experimental._backward_state import BackwardState |
|
|
from torch.fx.experimental._dynamism import normalize_source_name |
|
|
from torch.fx.experimental.symbolic_shapes import ( |
|
|
_constrain_range_for_size, |
|
|
_nested_int_aware_sort, |
|
|
DimDynamic, |
|
|
RelaxedUnspecConstraint, |
|
|
StatefulSymbolicContext, |
|
|
SubclassSymbolicContext, |
|
|
SymbolicContext, |
|
|
SymIntSymbolicContext, |
|
|
TrackedFake, |
|
|
) |
|
|
from torch.fx.immutable_collections import immutable_dict, immutable_list |
|
|
from torch.nn.utils._expanded_weights import ExpandedWeight |
|
|
from torch.utils._python_dispatch import ( |
|
|
is_traceable_wrapper_subclass, |
|
|
is_traceable_wrapper_subclass_type, |
|
|
) |
|
|
from torch.utils._sympy.value_ranges import ValueRanges |
|
|
from torch.utils.weak import TensorWeakRef |
|
|
|
|
|
from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules |
|
|
from ..device_interface import get_registered_device_interfaces |
|
|
from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2 |
|
|
from ..guards import GuardBuilder, install_guard, make_dupe_guard |
|
|
from ..pgo import ( |
|
|
auto_dynamic, |
|
|
auto_unset, |
|
|
FrameStateSizeEntry, |
|
|
InferStride, |
|
|
process_automatic_dynamic, |
|
|
) |
|
|
from ..side_effects import SideEffects |
|
|
from ..source import ( |
|
|
AttrProxySource, |
|
|
AttrSource, |
|
|
CallMethodItemSource, |
|
|
ChainedSource, |
|
|
ConstDictKeySource, |
|
|
ConvertIntSource, |
|
|
DictGetItemSource, |
|
|
DictSubclassGetItemSource, |
|
|
FloatTensorSource, |
|
|
GetItemSource, |
|
|
GradSource, |
|
|
is_constant_source, |
|
|
is_from_closure_source, |
|
|
is_from_global_source, |
|
|
is_from_nonlocal_source, |
|
|
is_from_optimizer_source, |
|
|
is_from_unspecialized_nn_module_source, |
|
|
ListGetItemSource, |
|
|
LocalSource, |
|
|
NonSerializableSetGetItemSource, |
|
|
NumpyTensorSource, |
|
|
OptimizerSource, |
|
|
RandomValueSource, |
|
|
Source, |
|
|
SubclassAttrListSource, |
|
|
TupleIteratorGetItemSource, |
|
|
UnspecializedBuiltinNNModuleSource, |
|
|
UnspecializedNNModuleSource, |
|
|
) |
|
|
from ..utils import ( |
|
|
_extract_tensor_dict, |
|
|
build_checkpoint_variable, |
|
|
build_invoke_subgraph_variable, |
|
|
clone_input, |
|
|
common_constant_types, |
|
|
dict_keys, |
|
|
get_fake_value, |
|
|
get_items_from_dict, |
|
|
get_locals_to_steal, |
|
|
get_static_address_type, |
|
|
is_frozen_dataclass, |
|
|
is_function, |
|
|
is_function_or_wrapper, |
|
|
is_invoke_subgraph, |
|
|
is_lru_cache_wrapped_function, |
|
|
is_namedtuple, |
|
|
is_parameter_freezing, |
|
|
is_typing, |
|
|
is_utils_checkpoint, |
|
|
is_wrapper_or_member_descriptor, |
|
|
istype, |
|
|
namedtuple_fields, |
|
|
odict_values, |
|
|
proxy_args_kwargs, |
|
|
range_iterator, |
|
|
set_example_value, |
|
|
tensor_always_has_static_shape, |
|
|
tuple_iterator, |
|
|
tuple_iterator_getitem, |
|
|
tuple_iterator_len, |
|
|
unwrap_with_attr_name_if_wrapper, |
|
|
wrap_fake_exception, |
|
|
) |
|
|
from .base import ( |
|
|
AttributeMutationNew, |
|
|
typestr, |
|
|
ValueMutationExisting, |
|
|
ValueMutationNew, |
|
|
VariableTracker, |
|
|
VariableTrackerMeta, |
|
|
) |
|
|
from .builtin import BuiltinVariable |
|
|
from .constant import ConstantVariable, EnumVariable |
|
|
from .ctx_manager import ( |
|
|
AutocastModeVariable, |
|
|
DynamoConfigPatchVariable, |
|
|
ErrorOnGraphBreakVariable, |
|
|
EventVariable, |
|
|
NullContextVariable, |
|
|
PreserveVersionContextVariable, |
|
|
StreamContextVariable, |
|
|
StreamVariable, |
|
|
) |
|
|
from .dicts import ( |
|
|
ConstDictVariable, |
|
|
DefaultDictVariable, |
|
|
DictKeySetVariable, |
|
|
FrozensetVariable, |
|
|
MappingProxyVariable, |
|
|
SetVariable, |
|
|
) |
|
|
from .distributed import ( |
|
|
DeviceMeshVariable, |
|
|
PlacementClassVariable, |
|
|
PlacementVariable, |
|
|
ProcessGroupVariable, |
|
|
WorldMetaClassVariable, |
|
|
) |
|
|
from .functions import ( |
|
|
BuiltinMethodVariable, |
|
|
CollectionsNamedTupleFunction, |
|
|
CollectiveFunctionRewriteVariable, |
|
|
CreateTMADescriptorExperimentalVariable, |
|
|
CreateTMADescriptorStableVariable, |
|
|
FunctoolsPartialVariable, |
|
|
FunctoolsWrapsVariable, |
|
|
SysFunctionVariable, |
|
|
TracebackVariable, |
|
|
TritonKernelVariable, |
|
|
UserFunctionVariable, |
|
|
UserMethodVariable, |
|
|
WrapperUserFunctionVariable, |
|
|
) |
|
|
from .higher_order_ops import TorchHigherOrderOperatorVariable |
|
|
from .iter import ItertoolsVariable |
|
|
from .lazy import LazyVariableTracker |
|
|
from .lists import ( |
|
|
BaseListVariable, |
|
|
ListIteratorVariable, |
|
|
ListVariable, |
|
|
NamedTupleVariable, |
|
|
RangeVariable, |
|
|
SizeVariable, |
|
|
SliceVariable, |
|
|
TupleIteratorVariable, |
|
|
TupleVariable, |
|
|
) |
|
|
from .misc import ( |
|
|
AutogradEngineVariable, |
|
|
AutogradFunctionContextVariable, |
|
|
AutogradFunctionVariable, |
|
|
ComptimeVariable, |
|
|
DebuggingVariable, |
|
|
DelayGraphBreakVariable, |
|
|
GetAttrVariable, |
|
|
GetSetDescriptorVariable, |
|
|
LambdaVariable, |
|
|
LoggingLoggerVariable, |
|
|
MethodWrapperVariable, |
|
|
NumpyDTypeVariable, |
|
|
NumpyTypeInfoVariable, |
|
|
NumpyVariable, |
|
|
PythonModuleVariable, |
|
|
RandomClassVariable, |
|
|
RandomVariable, |
|
|
RegexPatternVariable, |
|
|
SavedTensorBox, |
|
|
TorchVersionVariable, |
|
|
TypingVariable, |
|
|
WeakRefVariable, |
|
|
) |
|
|
from .nn_module import ( |
|
|
FSDPManagedNNModuleVariable, |
|
|
UnspecializedBuiltinNNModuleVariable, |
|
|
UnspecializedNNModuleVariable, |
|
|
) |
|
|
from .optimizer import OptimizerVariable |
|
|
from .script_object import TorchScriptObjectVariable |
|
|
from .sdpa import SDPAParamsVariable |
|
|
from .tensor import ( |
|
|
NumpyNdarrayVariable, |
|
|
supported_const_comparison_op_values, |
|
|
SymNodeVariable, |
|
|
TensorSubclassVariable, |
|
|
TensorVariable, |
|
|
UnspecializedPythonVariable, |
|
|
) |
|
|
from .torch import ( |
|
|
DispatchKeySetVariable, |
|
|
FuncTorchInterpreterVariable, |
|
|
TorchCtxManagerClassVariable, |
|
|
TorchInGraphFunctionVariable, |
|
|
) |
|
|
from .torch_function import ( |
|
|
TensorWithTFOverrideVariable, |
|
|
torch_function_mode_stack_state_mgr, |
|
|
TorchFunctionModeVariable, |
|
|
) |
|
|
from .user_defined import ( |
|
|
FrozenDataClassVariable, |
|
|
IntWrapperVariable, |
|
|
KeyedJaggedTensorVariable, |
|
|
MutableMappingVariable, |
|
|
SourcelessGraphModuleVariable, |
|
|
UserDefinedClassVariable, |
|
|
UserDefinedDictVariable, |
|
|
UserDefinedExceptionClassVariable, |
|
|
UserDefinedListVariable, |
|
|
UserDefinedObjectVariable, |
|
|
UserDefinedSetVariable, |
|
|
UserDefinedTupleVariable, |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
import numpy as np |
|
|
except ModuleNotFoundError: |
|
|
np = None |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from torch._dynamo.codegen import PyCodegen |
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator |
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
static_inputs_log = torch._logging.getArtifactLogger( |
|
|
__name__, "cudagraph_static_inputs" |
|
|
) |
|
|
|
|
|
|
|
|
DimList = list |
|
|
|
|
|
|
|
|
def safe_has_grad(t): |
|
|
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): |
|
|
return hasattr(t, "grad") |
|
|
|
|
|
|
|
|
class _missing:
    """Empty marker class with no behavior of its own.

    NOTE(review): presumably used as a "no value supplied" sentinel; its use
    sites are outside the visible part of this file -- confirm.
    """
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class GraphArg: |
|
|
source: Source |
|
|
|
|
|
|
|
|
|
|
|
_example: Union[TensorWeakRef, torch.SymInt] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pass_arg_as_tensor: bool |
|
|
fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_tensor: bool = True |
|
|
|
|
|
|
|
|
|
|
|
example_strong_ref: Optional[torch.Tensor] = None |
|
|
|
|
|
@property |
|
|
def example(self): |
|
|
if isinstance(self._example, TensorWeakRef): |
|
|
r = self._example() |
|
|
assert r is not None |
|
|
return r |
|
|
else: |
|
|
return self._example |
|
|
|
|
|
def __post_init__(self): |
|
|
if isinstance(self._example, torch.Tensor): |
|
|
self._example = TensorWeakRef(self._example) |
|
|
assert is_fake(self.fake_tensor) |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
codegen(self.source) |
|
|
|
|
|
def erase(self): |
|
|
self._example = None |
|
|
self.example_strong_ref = None |
|
|
|
|
|
def __eq__(self, other): |
|
|
return self.source.name() == other.source.name() |
|
|
|
|
|
|
|
|
class BackwardStateGraphArg(GraphArg):
    """GraphArg whose example is a freshly constructed ``BackwardState``.

    It has no source and no fake tensor, and is flagged as a non-tensor.
    """

    def __init__(self) -> None:
        field_values = {
            "source": None,
            "_example": BackwardState(),
            "pass_arg_as_tensor": False,
            "fake_tensor": None,
            "is_tensor": False,
        }
        super().__init__(**field_values)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit code that builds a ``BackwardState`` and stores it.

        The codegen calls made are, in order: push-null plus import of
        ``BackwardState`` from its defining module, a zero-argument call,
        ``dup_top``, and a store into ``tx.output.backward_state_var``.
        """
        # A target variable must already have been set on the output graph.
        assert codegen.tx.output.backward_state_var

        def load_backward_state_class():
            return codegen.load_import_from(BackwardState.__module__, "BackwardState")

        codegen.add_push_null(load_backward_state_class)
        codegen.call_function(0, False)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ids of every public (non-underscore) class found in the itertools module.
ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset(
    id(obj)
    for attr, obj in vars(itertools).items()
    if inspect.isclass(obj) and not attr.startswith("_")
)

# ids of itertools classes to exclude from ItertoolsVariable wrapping; starts
# out empty.
# NOTE(review): presumably filled in wherever polyfills are registered -- that
# code is not visible here.
ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set()

# References to torch.nn.Module.named_buffers / named_parameters taken when
# this module is imported.
# NOTE(review): "og" presumably means "original", kept to detect overrides --
# confirm at use sites.
og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers
og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters
|
|
|
|
|
|
|
|
class VariableBuilder: |
|
|
"""Wrap a python value in a VariableTracker() instance""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
tx, |
|
|
source: Source, |
|
|
) -> None: |
|
|
assert source is not None, ( |
|
|
"Consider SourcelessBuilder for ephemeral objects, usually objects created locally." |
|
|
) |
|
|
assert TracingContext.try_get() is not None, "Expected active TracingContext" |
|
|
super().__init__() |
|
|
self.tx = tx |
|
|
self.source = source |
|
|
self.name = source.name() |
|
|
|
|
|
def __call__(self, value):
    """Wrap ``value`` in a VariableTracker, reusing an existing one if possible.

    Lookup order:
      1. a tracker already registered in the side-effects table (a dupe guard
         between the two sources is installed when one is produced);
      2. the per-output variable tracker cache for ``(value, source)``;
      3. a fresh tracker from ``self._wrap``, which is then cached.
    """
    tracked_objects = self.tx.output.side_effects
    if value in tracked_objects:
        existing_vt = tracked_objects[value]
        dupe = make_dupe_guard(self.source, existing_vt.source)
        if dupe:
            self.install_guards(dupe)
        return existing_vt

    cache_hit = self.tx.output.variable_tracker_cache.lookup(value, self.source)
    if cache_hit:
        return cache_hit

    vt = self._wrap(value)

    # From here on ``self.source`` and the output attributes are read afresh:
    # ``_wrap`` can rebind ``self.source`` (its optimizer branch does).
    if vt.source is None:
        vt.source = self.source

    # Trackers whose attributes can be lifted to inputs, and symbolic values
    # wrapped as SymNodeVariable, are registered with side effects so later
    # wraps of the same object resolve to this tracker.
    should_track = self._can_lift_attrs_to_inputs(vt) or (
        is_torch_sym(value) and isinstance(vt, SymNodeVariable)
    )
    if (
        should_track
        and value not in self.tx.output.side_effects
        and not is_wrapper_or_member_descriptor(value)
    ):
        vt = self.tx.output.side_effects.track_object_existing(value, vt)

    self.tx.output.variable_tracker_cache.add(value, self.source, vt)
    return vt
|
|
|
|
|
def _can_lift_attrs_to_inputs(self, vt):
    """Return True when ``vt``'s exact type is one of the liftable tracker types.

    The match is on ``type(vt)`` itself, so subclasses of the listed types do
    not qualify.
    """
    liftable_types = (
        TensorVariable,
        TensorWithTFOverrideVariable,
        UserDefinedObjectVariable,
        NumpyNdarrayVariable,
    )
    return type(vt) in liftable_types
|
|
|
|
|
def get_source(self): |
|
|
return self.source |
|
|
|
|
|
def install_guards(self, *guards): |
|
|
source = self.get_source() |
|
|
try: |
|
|
tmp = [source.make_guard(guard) for guard in guards] |
|
|
except NotImplementedError: |
|
|
return None |
|
|
install_guard(*tmp, skip=1) |
|
|
return {} |
|
|
|
|
|
@classmethod
def _type_dispatch(cls):
    """Return the exact-type dispatch table for the current ``config.trace_numpy``."""
    # The config value is passed explicitly because the cached implementation
    # takes it as an argument.
    numpy_tracing = config.trace_numpy
    return cls._type_dispatch_impl(numpy_tracing)
|
|
|
|
|
@classmethod
@functools.cache
def _type_dispatch_impl(cls, trace_numpy):
    """Build (and cache) the mapping from exact Python type to wrap method.

    Args:
        trace_numpy: when truthy and numpy is importable, ``np.ndarray`` is
            routed to ``wrap_numpy_ndarray``.

    Returns:
        dict mapping a type to the unbound ``VariableBuilder`` method that
        wraps values of exactly that type.

    Raises:
        AssertionError: if a type is listed more than once.
    """
    tensor_types = (
        torch.Tensor,
        torch.nn.Parameter,
        torch._subclasses.FakeTensor,
        torch._subclasses.functional_tensor.FunctionalTensor,
    )
    listlike_types = (tuple, list, odict_values, collections.deque, torch.Size)

    handlers = [
        (tensor_types, cls.wrap_tensor),
        (listlike_types, cls.wrap_listlike),
        (tuple_iterator, cls.wrap_tuple_iterator),
        (range_iterator, cls.wrap_range_iterator),
        ((slice, range), cls.wrap_slice_range),
        (tuple(common_constant_types), cls.wrap_literal),
        (re.Pattern, cls.wrap_regex_pattern),
        (weakref.ReferenceType, cls.wrap_weakref),
        (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
        (torch.jit.ScriptFunction, cls.wrap_jit_function),
        (types.MappingProxyType, cls.wrap_mapping_proxy),
    ]

    if trace_numpy and np:
        handlers.append((np.ndarray, cls.wrap_numpy_ndarray))

    table = {}
    for type_spec, handler in handlers:
        # A spec is either a single type or a tuple of types sharing a handler.
        type_group = type_spec if isinstance(type_spec, tuple) else (type_spec,)
        for exact_type in type_group:
            assert exact_type not in table
            table[exact_type] = handler

    return table
|
|
|
|
|
def wrap_regex_pattern(self, value: re.Pattern):
    """Wrap a compiled regex, guarding on the identity of the pattern object."""
    self.install_guards(GuardBuilder.ID_MATCH)
    pattern_vt = RegexPatternVariable(value)
    return pattern_vt
|
|
|
|
|
def wrap_weakref(self, value: weakref.ReferenceType):
    """Wrap a weak reference: install a TYPE_MATCH guard, then build via WeakRefVariable."""
    self.install_guards(GuardBuilder.TYPE_MATCH)
    weakref_vt = WeakRefVariable.build(self.tx, value, source=self.source)
    return weakref_vt
|
|
|
|
|
def wrap_removable_handle(self, value):
    """Report a RemovableHandle as unsupported via ``unimplemented_v2``.

    No tracker is built for ``value``; this method only reports the
    unsupported case.
    """
    explanation = (
        "Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, "
        "which is not supported. This happens because the RemovableHandle was created in another frame."
    )
    unimplemented_v2(
        gb_type="Attempted to represent unregistered RemovableHandle",
        context="",
        explanation=explanation,
        hints=[],
    )
|
|
|
|
|
def wrap_jit_function(self, value):
    """Wrap a ``torch.jit.ScriptFunction`` under a TYPE_MATCH guard.

    The resulting WrapperUserFunctionVariable is given the attribute name
    ``"_torchdynamo_inline"``.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    inline_attr = "_torchdynamo_inline"
    return WrapperUserFunctionVariable(value, inline_attr, source=self.source)
|
|
|
|
|
def wrap_mapping_proxy(self, value):
    """Wrap a ``types.MappingProxyType`` whose keys are all literals.

    TYPE_MATCH and MAPPING_KEYS_CHECK guards are installed first. A proxy with
    any non-literal key is reported through ``unimplemented_v2``. Otherwise
    each value becomes a lazy tracker sourced from ``proxy[key]``, and the
    resulting MappingProxyVariable is registered as a tracked mutable object.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK)

    keys_are_literal = all(ConstantVariable.is_literal(k) for k in value.keys())
    if not keys_are_literal:
        unimplemented_v2(
            gb_type="non-const keys in mappingproxy",
            context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}",
            explanation="Dynamo expects mappingproxy keys to be constants.",
            hints=[
                "Ensure your mappingproxy keys are constants (e.g. int, float, strings)",
            ],
        )

    wrapped_items = {}
    for raw_key, raw_value in value.items():
        key_vt = ConstantVariable.create(raw_key)
        # Keys are literals here, so the raw key itself indexes the source.
        item_source = GetItemSource(self.get_source(), raw_key)
        wrapped_items[key_vt] = LazyVariableTracker.create(raw_value, item_source)

    # The inner dict tracker carries no source of its own; only the proxy does.
    inner_dict_vt = ConstDictVariable(wrapped_items, source=None)
    proxy_vt = MappingProxyVariable(inner_dict_vt, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, proxy_vt)
|
|
|
|
|
@classmethod
@functools.cache
def _id_dispatch(
    cls,
) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
    """Build (and cache) the mapping from ``id(obj)`` to a factory for that object.

    Each factory is called as ``factory(builder, value)`` and returns the
    VariableTracker that represents that specific object.

    Raises:
        AssertionError: if the same object is registered more than once.
    """
    from ..comptime import comptime

    entries = [
        (comptime, lambda self, value: ComptimeVariable()),
        (
            dataclasses.fields,
            lambda self, value: LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                **self.install_guards(GuardBuilder.FUNCTION_MATCH),
            ),
        ),
        (torch.__version__, lambda self, value: TorchVersionVariable()),
    ]

    result = {}
    for ts, fn in entries:
        # An entry is either a single object or a tuple/list of objects
        # sharing one factory.
        for t in ts if isinstance(ts, (tuple, list)) else (ts,):
            # The table is keyed by id(t), so the duplicate check has to look
            # up id(t) as well; testing ``t`` itself could never match a key.
            assert id(t) not in result
            result[id(t)] = fn

    return result
|
|
|
|
|
def _wrap(self, value): |
|
|
|
|
|
from torch.utils._triton import ( |
|
|
has_triton, |
|
|
has_triton_experimental_host_tma, |
|
|
has_triton_tensor_descriptor_host_tma, |
|
|
) |
|
|
|
|
|
from ..decorators import ( |
|
|
DynamoConfigPatchProxy, |
|
|
ErrorOnGraphBreakDecoratorContextManager, |
|
|
) |
|
|
|
|
|
if has_triton(): |
|
|
from triton.runtime.autotuner import Autotuner |
|
|
from triton.runtime.jit import JITFunction |
|
|
else: |
|
|
|
|
|
class JITFunction: |
|
|
pass |
|
|
|
|
|
class Autotuner: |
|
|
pass |
|
|
|
|
|
|
|
|
def create_1d_tma_descriptor(): |
|
|
pass |
|
|
|
|
|
def create_2d_tma_descriptor(): |
|
|
pass |
|
|
|
|
|
class TensorDescriptor: |
|
|
@staticmethod |
|
|
def from_tensor(): |
|
|
pass |
|
|
|
|
|
if has_triton_experimental_host_tma(): |
|
|
from triton.tools.experimental_descriptor import ( |
|
|
create_1d_tma_descriptor, |
|
|
create_2d_tma_descriptor, |
|
|
) |
|
|
if has_triton_tensor_descriptor_host_tma(): |
|
|
from triton.tools.tensor_descriptor import TensorDescriptor |
|
|
|
|
|
|
|
|
type_dispatch = self._type_dispatch().get(type(value)) |
|
|
if type_dispatch is not None: |
|
|
return type_dispatch(self, value) |
|
|
|
|
|
|
|
|
id_dispatch = self._id_dispatch().get(id(value)) |
|
|
if id_dispatch is not None: |
|
|
return id_dispatch(self, value) |
|
|
|
|
|
|
|
|
if ( |
|
|
isinstance(value, torch.Tensor) |
|
|
and type(value) |
|
|
not in ( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.nn.parameter.UninitializedBuffer, |
|
|
torch.nn.parameter.UninitializedParameter, |
|
|
ExpandedWeight, |
|
|
) |
|
|
and type(value) not in config.nontraceable_tensor_subclasses |
|
|
): |
|
|
if ( |
|
|
type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__ |
|
|
or is_traceable_wrapper_subclass(value) |
|
|
): |
|
|
return self.wrap_tensor(value) |
|
|
|
|
|
if is_namedtuple(value): |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
output = [ |
|
|
LazyVariableTracker.create( |
|
|
getattr(value, name), |
|
|
source=AttrSource(self.source, name), |
|
|
) |
|
|
for name in namedtuple_fields(type(value)) |
|
|
] |
|
|
result = NamedTupleVariable( |
|
|
output, tuple_cls=type(value), source=self.source |
|
|
) |
|
|
return result |
|
|
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not all_const: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.tx.output.guard_on_key_order.add(self.source) |
|
|
|
|
|
|
|
|
|
|
|
def build_key_value(i, k, v): |
|
|
base = self.get_source() |
|
|
if all_const: |
|
|
key = ConstantVariable.create(k) |
|
|
source_key = k |
|
|
else: |
|
|
source_key = ConstDictKeySource(base, i) |
|
|
key = LazyVariableTracker.create(k, source_key) |
|
|
source_value = DictGetItemSource(base, source_key) |
|
|
res_value = LazyVariableTracker.create(v, source_value) |
|
|
|
|
|
return key, res_value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = dict( |
|
|
build_key_value(i, k, v) |
|
|
for i, (k, v) in enumerate(get_items_from_dict(value)) |
|
|
) |
|
|
|
|
|
if istype(value, collections.defaultdict): |
|
|
factory_source = AttrSource(self.source, "default_factory") |
|
|
result = DefaultDictVariable( |
|
|
result, |
|
|
type(value), |
|
|
default_factory=VariableBuilder(self.tx, factory_source)( |
|
|
value.default_factory |
|
|
), |
|
|
source=self.source, |
|
|
) |
|
|
else: |
|
|
result = ConstDictVariable( |
|
|
result, user_cls=type(value), source=self.source |
|
|
) |
|
|
|
|
|
return self.tx.output.side_effects.track_mutable(value, result) |
|
|
elif isinstance(value, torch.nn.Module): |
|
|
return self.wrap_module(value) |
|
|
elif ConstantVariable.is_literal(value): |
|
|
return self.wrap_literal(value) |
|
|
elif isinstance(value, torch.overrides.TorchFunctionMode): |
|
|
var = TorchFunctionModeVariable(value, source=self.source) |
|
|
self.tx.output.side_effects.track_object_existing(value, var) |
|
|
return var |
|
|
elif istype(value, set): |
|
|
if any(isinstance(x, torch.Tensor) for x in value): |
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap a set with tensors", |
|
|
context="Python set containing torch.Tensor elements", |
|
|
explanation=( |
|
|
"Dynamo cannot trace sets of tensors. To get a stable ordering, " |
|
|
"Dynamo needs to convert the set into a list and the order might not be " |
|
|
"stable if the set contains tensors." |
|
|
), |
|
|
hints=[ |
|
|
"Use a dictionary where the keys are tensors.", |
|
|
*graph_break_hints.SUPPORTABLE, |
|
|
], |
|
|
) |
|
|
|
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
L = list(value) |
|
|
items = [ |
|
|
LazyVariableTracker.create( |
|
|
v, source=NonSerializableSetGetItemSource(self.source, i) |
|
|
) |
|
|
for i, v in enumerate(L) |
|
|
] |
|
|
result = SetVariable(items, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif istype(value, frozenset) and all( |
|
|
( |
|
|
|
|
|
(type(x) is types.BuiltinMethodType and x.__module__ == "torch") |
|
|
or |
|
|
|
|
|
x in torch.utils._pytree.BUILTIN_TYPES |
|
|
) |
|
|
for x in value |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
items = [SourcelessBuilder.create(self.tx, v) for v in value] |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return FrozensetVariable(items, source=self.source) |
|
|
elif isinstance( |
|
|
value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) |
|
|
): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return EnumVariable(value=value, source=self.source) |
|
|
elif DebuggingVariable.is_reorderable_logging_function(value): |
|
|
|
|
|
|
|
|
self.install_guards(GuardBuilder.BUILTIN_MATCH) |
|
|
return DebuggingVariable(value, source=self.source) |
|
|
elif isinstance(value, logging.Logger): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return LoggingLoggerVariable(value, source=self.source) |
|
|
elif is_utils_checkpoint(value): |
|
|
return build_checkpoint_variable(source=self.source) |
|
|
elif is_invoke_subgraph(value): |
|
|
return build_invoke_subgraph_variable(source=self.source) |
|
|
elif isinstance(value, functools.partial): |
|
|
func_src = AttrSource(self.get_source(), "func") |
|
|
func_obj = VariableBuilder(self.tx, func_src)(value.func) |
|
|
|
|
|
args = [] |
|
|
args_source = AttrSource(self.get_source(), "args") |
|
|
for i, arg in enumerate(value.args): |
|
|
args.append( |
|
|
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) |
|
|
) |
|
|
|
|
|
keywords = {} |
|
|
keywords_source = AttrSource(self.get_source(), "keywords") |
|
|
for k, v in value.keywords.items(): |
|
|
if not ConstantVariable.is_literal(k): |
|
|
unimplemented_v2( |
|
|
gb_type="functools.partial() with non-literal keyword", |
|
|
context=f"non-literal keyword: {k}", |
|
|
explanation="functools.partial() expects literal/string keywords", |
|
|
hints=[*graph_break_hints.USER_ERROR], |
|
|
) |
|
|
keywords[k] = VariableBuilder( |
|
|
self.tx, DictGetItemSource(keywords_source, k) |
|
|
)(v) |
|
|
|
|
|
install_guard( |
|
|
self.get_source().make_guard(GuardBuilder.TYPE_MATCH), |
|
|
keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), |
|
|
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), |
|
|
) |
|
|
return FunctoolsPartialVariable(func_obj, args, keywords) |
|
|
elif is_typing(value): |
|
|
|
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return TypingVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif np is not None and isinstance(value, np.generic): |
|
|
|
|
|
return self.wrap_numpy_ndarray(np.asarray(value)) |
|
|
elif trace_rules.is_numpy(value): |
|
|
assert np |
|
|
self.install_guards( |
|
|
GuardBuilder.FUNCTION_MATCH |
|
|
if callable(value) |
|
|
else GuardBuilder.TYPE_MATCH |
|
|
) |
|
|
return NumpyVariable(value, source=self.source) |
|
|
elif trace_rules.is_numpy_dtype(value): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return NumpyDTypeVariable(value, source=self.source) |
|
|
elif trace_rules.is_numpy_type_info(value): |
|
|
if isinstance(value, np.iinfo): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
dt_source = AttrSource(self.source, "dtype") |
|
|
install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) |
|
|
else: |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return NumpyTypeInfoVariable(value, source=self.source) |
|
|
|
|
|
elif CollectiveFunctionRewriteVariable.can_rewrite(value): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return CollectiveFunctionRewriteVariable.create( |
|
|
self.tx, |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif istype(value, torch.autograd.function.FunctionMeta): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return AutogradFunctionVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif isinstance(value, torch.autograd.function.FunctionCtx): |
|
|
actual_saved_tensors = None |
|
|
try: |
|
|
actual_saved_tensors = value.saved_tensors |
|
|
except RuntimeError: |
|
|
pass |
|
|
|
|
|
saved_tensors = [] |
|
|
guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)] |
|
|
if isinstance(actual_saved_tensors, tuple): |
|
|
saved_tensors_source = AttrSource(self.source, "saved_tensors") |
|
|
guards.append( |
|
|
saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH) |
|
|
) |
|
|
for i, v in enumerate(actual_saved_tensors): |
|
|
saved_tensors.append( |
|
|
VariableBuilder( |
|
|
self.tx, GetItemSource(saved_tensors_source, i) |
|
|
)(v) |
|
|
) |
|
|
install_guard(*guards) |
|
|
|
|
|
return self.tx.output.side_effects.track_object_existing( |
|
|
value, |
|
|
AutogradFunctionContextVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
saved_tensors=SavedTensorBox(saved_tensors), |
|
|
), |
|
|
) |
|
|
elif ( |
|
|
isinstance(value, types.MethodType) |
|
|
and istype( |
|
|
getattr(value, "__self__", None), torch.autograd.function.FunctionMeta |
|
|
) |
|
|
and getattr(value, "__name__", "") == "apply" |
|
|
and value == getattr(value.__self__, "apply", None) |
|
|
): |
|
|
|
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return GetAttrVariable( |
|
|
AutogradFunctionVariable( |
|
|
value.__self__, source=AttrSource(self.source, member="__self__") |
|
|
), |
|
|
"apply", |
|
|
) |
|
|
elif isinstance(value, torch._C._ImperativeEngine): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return AutogradEngineVariable(value, source=self.source) |
|
|
elif ( |
|
|
value |
|
|
is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub |
|
|
): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return LambdaVariable( |
|
|
lambda: UserFunctionVariable( |
|
|
torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, |
|
|
).call_function( |
|
|
self.tx, |
|
|
(self.tx.output.side_effects.get_ca_final_callbacks_var(),), |
|
|
{}, |
|
|
) |
|
|
) |
|
|
elif isinstance(value, DynamoConfigPatchProxy): |
|
|
return DynamoConfigPatchVariable(value.changes) |
|
|
elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager): |
|
|
return ErrorOnGraphBreakVariable(value.error_on_graph_break) |
|
|
elif callable(value) and trace_rules.lookup_callable(value) is not None: |
|
|
if trace_rules.is_callable_allowed(value): |
|
|
self.tx.output.has_user_defined_allowed_in_graph = True |
|
|
return trace_rules.lookup_callable(value).create_with_source( |
|
|
value, source=self.source |
|
|
) |
|
|
elif np and isinstance(value, np.number): |
|
|
return self.wrap_unspecialized_primitive(value) |
|
|
elif isinstance(value, HigherOrderOperator): |
|
|
if value is torch._higher_order_ops.invoke_subgraph: |
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", |
|
|
context="", |
|
|
explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", |
|
|
hints=[], |
|
|
) |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH) |
|
|
return TorchHigherOrderOperatorVariable.make(value, source=self.source) |
|
|
elif isinstance(value, torch.cuda.StreamContext): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
stream_source = AttrSource(self.source, "stream") |
|
|
stream_var = VariableBuilder(self.tx, stream_source)(value.stream) |
|
|
return StreamContextVariable.create(self.tx, stream_var) |
|
|
elif isinstance(value, torch.Stream): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
stream_proxy = self.tx.output.create_proxy( |
|
|
"call_function", |
|
|
type(value), |
|
|
(), |
|
|
{ |
|
|
"stream_id": value.stream_id, |
|
|
"device_index": value.device_index, |
|
|
"device_type": value.device_type, |
|
|
}, |
|
|
) |
|
|
set_example_value(stream_proxy.node, value) |
|
|
return StreamVariable( |
|
|
stream_proxy, |
|
|
value, |
|
|
value.device, |
|
|
source=self.source, |
|
|
) |
|
|
elif isinstance(value, (torch._C._SDPAParams)): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return SDPAParamsVariable.create(self.tx, value, self.source) |
|
|
elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return FuncTorchInterpreterVariable(value) |
|
|
elif isinstance(value, torch.Event): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
torch._dynamo.utils.store_user_object_weakref(value) |
|
|
event_proxy = self.tx.output.create_proxy( |
|
|
"call_function", |
|
|
torch._dynamo.utils.get_user_object_from_id, |
|
|
(id(value),), |
|
|
{}, |
|
|
) |
|
|
set_example_value(event_proxy.node, value) |
|
|
return EventVariable( |
|
|
event_proxy, |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif ( |
|
|
istype(value, contextlib.nullcontext) |
|
|
and inspect.getattr_static(value, "enter_result", None) is None |
|
|
): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return NullContextVariable(source=self.source) |
|
|
elif KeyedJaggedTensorVariable.is_matching_object(value): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
result = KeyedJaggedTensorVariable(value, source=self.source) |
|
|
|
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif isinstance(value, torch.optim.Optimizer): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
self.source = OptimizerSource(self.source) |
|
|
return OptimizerVariable(value, source=self.source) |
|
|
elif isinstance(value, torch.DispatchKeySet): |
|
|
self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH) |
|
|
return DispatchKeySetVariable(value) |
|
|
elif WorldMetaClassVariable.is_group_member_type(value): |
|
|
return WorldMetaClassVariable(value, source=self.source) |
|
|
elif ProcessGroupVariable.is_process_group(value): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return ProcessGroupVariable(value, source=self.source) |
|
|
elif DeviceMeshVariable.is_device_mesh(value): |
|
|
|
|
|
self.install_guards(GuardBuilder.EQUALS_MATCH) |
|
|
return DeviceMeshVariable(value, source=self.source) |
|
|
elif PlacementClassVariable.is_placement_type(value): |
|
|
|
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return PlacementClassVariable(value, source=self.source) |
|
|
elif PlacementVariable.is_placement(value): |
|
|
|
|
|
self.install_guards(GuardBuilder.EQUALS_MATCH) |
|
|
return PlacementVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif ( |
|
|
id(value) in ITERTOOLS_TYPE_IDS |
|
|
and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS |
|
|
): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return ItertoolsVariable(value, source=self.source) |
|
|
elif is_torch_sym(value): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
source = ( |
|
|
self.source |
|
|
if isinstance(value, torch.SymInt) |
|
|
else ConvertIntSource(self.source) |
|
|
) |
|
|
if value.node.has_hint(): |
|
|
new_symint = ( |
|
|
self.tx.output.shape_env.create_unspecified_symint_and_symbol( |
|
|
int(value.node.hint), |
|
|
source, |
|
|
dynamic_dim=DimDynamic.DYNAMIC, |
|
|
) |
|
|
) |
|
|
else: |
|
|
if isinstance(value, torch.SymBool): |
|
|
|
|
|
new_symint = self.tx.output.shape_env.create_unbacked_symint() |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap unbacked SymInt", |
|
|
context="", |
|
|
explanation="Unbacked SymInt input is not supported yet.", |
|
|
hints=[*graph_break_hints.SUPPORTABLE], |
|
|
) |
|
|
|
|
|
sym_node_proxy = self.tx.output.root_tracer.create_graph_input( |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
type(new_symint), |
|
|
new_symint, |
|
|
source=source, |
|
|
) |
|
|
|
|
|
sym_node_proxy.node.meta["grapharg"] = GraphArg( |
|
|
source, |
|
|
new_symint, |
|
|
False, |
|
|
None, |
|
|
is_tensor=False, |
|
|
example_strong_ref=new_symint, |
|
|
) |
|
|
|
|
|
sym_expr = new_symint.node.expr |
|
|
assert isinstance(sym_expr, sympy.Symbol), ( |
|
|
f"{sym_expr} is not a basic Symbol." |
|
|
) |
|
|
self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None)) |
|
|
|
|
|
tracing_symint = ( |
|
|
new_symint if isinstance(value, torch.SymInt) else new_symint == 1 |
|
|
) |
|
|
return SymNodeVariable(sym_node_proxy, tracing_symint) |
|
|
|
|
|
elif isinstance(value, (JITFunction, Autotuner)): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return TritonKernelVariable( |
|
|
value, |
|
|
None, |
|
|
None, |
|
|
source=self.source, |
|
|
) |
|
|
elif value is create_1d_tma_descriptor: |
|
|
return CreateTMADescriptorExperimentalVariable(rank=1) |
|
|
elif value is create_2d_tma_descriptor: |
|
|
return CreateTMADescriptorExperimentalVariable(rank=2) |
|
|
elif value is TensorDescriptor.from_tensor: |
|
|
return CreateTMADescriptorStableVariable() |
|
|
elif isinstance(value, torch.amp.autocast_mode.autocast): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return AutocastModeVariable( |
|
|
target_values=[ |
|
|
value.device, |
|
|
value.fast_dtype, |
|
|
value._enabled, |
|
|
value._cache_enabled, |
|
|
], |
|
|
source=self.source, |
|
|
) |
|
|
elif TorchCtxManagerClassVariable.is_matching_cls(value): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return TorchCtxManagerClassVariable(value, source=self.source) |
|
|
elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return WrapperUserFunctionVariable( |
|
|
value, "__original_fn", source=self.source |
|
|
) |
|
|
elif is_lru_cache_wrapped_function(value): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) |
|
|
elif value is traceback.clear_frames: |
|
|
return TracebackVariable(source=self.source) |
|
|
elif value is sys.exc_info or ( |
|
|
sys.version_info >= (3, 11) and value is sys.exception |
|
|
): |
|
|
return SysFunctionVariable(value, source=self.source) |
|
|
elif is_function_or_wrapper(value) and inspect.getattr_static( |
|
|
value, "_torchdynamo_inline", False |
|
|
): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return WrapperUserFunctionVariable( |
|
|
value, "_torchdynamo_inline", source=self.source |
|
|
) |
|
|
elif value is functools.wraps: |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return FunctoolsWrapsVariable(value, source=self.source) |
|
|
elif value is collections.namedtuple: |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return CollectionsNamedTupleFunction(value, source=self.source) |
|
|
elif isinstance( |
|
|
value, types.BuiltinMethodType |
|
|
) and BuiltinMethodVariable.is_supported_builtin_method(value): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return BuiltinMethodVariable(value, source=self.source) |
|
|
elif is_function(value) and value in (float.fromhex, float.hex): |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return GetAttrVariable( |
|
|
BuiltinVariable(float, source=self.source), |
|
|
value.__name__, |
|
|
) |
|
|
elif is_function_or_wrapper(value): |
|
|
value, attr_name = unwrap_with_attr_name_if_wrapper(value) |
|
|
|
|
|
|
|
|
if attr_name is not None: |
|
|
self.source = AttrSource(self.source, attr_name) |
|
|
return trace_rules.lookup(value).create_with_source( |
|
|
value, source=self.source |
|
|
) |
|
|
elif value is random.Random: |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return RandomClassVariable(source=self.source) |
|
|
elif istype(value, random.Random) and RandomVariable.is_supported_random_obj( |
|
|
value |
|
|
): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
result = RandomVariable(value, source=self.source) |
|
|
self.tx.output.side_effects.track_mutable(value, result) |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
result = PythonModuleVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
self.tx.output.side_effects.track_object_existing(value, result) |
|
|
return result |
|
|
elif isinstance(value, types.MethodType) and isinstance( |
|
|
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self_obj = VariableBuilder( |
|
|
self.tx, source=AttrSource(self.source, "__self__") |
|
|
)(value.__self__) |
|
|
assert self_obj and isinstance(self_obj, VariableTracker), ( |
|
|
"Failed to produce a valid self obj" |
|
|
) |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return UserMethodVariable( |
|
|
value.__func__, |
|
|
self_obj, |
|
|
source=self.source, |
|
|
) |
|
|
elif isinstance(value, types.GetSetDescriptorType): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return GetSetDescriptorVariable(value) |
|
|
elif isinstance(value, types.MethodWrapperType): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return MethodWrapperVariable(value) |
|
|
elif issubclass(type(value), type) and issubclass(value, BaseException): |
|
|
|
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
return UserDefinedExceptionClassVariable(value) |
|
|
elif issubclass(type(value), type): |
|
|
if value in ( |
|
|
torch.utils.hooks.BackwardHook, |
|
|
torch.nn.Parameter, |
|
|
torch.nn.Buffer, |
|
|
): |
|
|
|
|
|
return trace_rules.lookup(value).create_with_source( |
|
|
value, source=self.source |
|
|
) |
|
|
if value is torch.autograd._unsafe_preserve_version_counter: |
|
|
self.install_guards(GuardBuilder.FUNCTION_MATCH) |
|
|
return PreserveVersionContextVariable.constructor(self.tx) |
|
|
if ( |
|
|
|
|
|
issubclass(value, torch.Tensor) |
|
|
and value is not torch.Tensor |
|
|
|
|
|
|
|
|
and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
and not is_traceable_wrapper_subclass_type(value) |
|
|
): |
|
|
return TensorSubclassVariable(value, source=self.source) |
|
|
|
|
|
if not is_from_closure_source(self.source): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
|
|
|
return UserDefinedClassVariable( |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
elif TorchScriptObjectVariable.is_matching_cls(type(value)): |
|
|
from ..source import ( |
|
|
FlattenScriptObjectSource, |
|
|
ScriptObjectQualifiedNameSource, |
|
|
) |
|
|
|
|
|
if torch._library.fake_class_registry.tracing_with_real(value): |
|
|
proxy = self.tx.output.root_tracer.create_graph_input( |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
type(value), |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proxy.node.meta["grapharg"] = GraphArg( |
|
|
self.source, value, False, None, False, value |
|
|
) |
|
|
return TorchScriptObjectVariable.create( |
|
|
proxy, |
|
|
value, |
|
|
source=self.source, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(value, "__obj_flatten__"): |
|
|
return self.wrap_user_defined(value) |
|
|
|
|
|
|
|
|
LazyVariableTracker.realize_all( |
|
|
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))( |
|
|
value._type().qualified_name() |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
LazyVariableTracker.realize_all( |
|
|
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))( |
|
|
value.__obj_flatten__() |
|
|
) |
|
|
) |
|
|
|
|
|
fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( |
|
|
self.tx.output.fake_mode, value |
|
|
) |
|
|
|
|
|
proxy = self.tx.output.root_tracer.create_graph_input( |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
type(value), |
|
|
fake_script_obj, |
|
|
source=self.source, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proxy.node.meta["grapharg"] = GraphArg( |
|
|
self.source, value, False, None, False, fake_script_obj |
|
|
) |
|
|
return TorchScriptObjectVariable.create( |
|
|
proxy, |
|
|
fake_script_obj, |
|
|
source=self.source, |
|
|
) |
|
|
elif ( |
|
|
isinstance(value, (dict, collections.OrderedDict)) |
|
|
and type(value).__new__ is dict.__new__ |
|
|
): |
|
|
|
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
|
|
|
self.tx.output.guard_on_key_order.add(self.source) |
|
|
|
|
|
|
|
|
|
|
|
def build_key_value(i, k, v): |
|
|
base = self.get_source() |
|
|
source_key = ConstDictKeySource(base, i) |
|
|
key = LazyVariableTracker.create(k, source_key) |
|
|
|
|
|
source_value = DictSubclassGetItemSource(base, source_key) |
|
|
res_value = LazyVariableTracker.create(v, source_value) |
|
|
|
|
|
return key, res_value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result = dict( |
|
|
build_key_value(i, k, v) |
|
|
for i, (k, v) in enumerate(get_items_from_dict(value)) |
|
|
) |
|
|
|
|
|
dict_vt = ConstDictVariable( |
|
|
result, |
|
|
user_cls=( |
|
|
collections.OrderedDict |
|
|
if isinstance(value, collections.OrderedDict) |
|
|
else dict |
|
|
), |
|
|
mutation_type=ValueMutationExisting(), |
|
|
source=self.source, |
|
|
) |
|
|
|
|
|
|
|
|
dict_vt.should_reconstruct_all = True |
|
|
|
|
|
result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif isinstance(value, tuple): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
|
|
|
|
|
|
output = [ |
|
|
LazyVariableTracker.create( |
|
|
tuple.__getitem__(value, i), |
|
|
source=GetItemSource(self.get_source(), i), |
|
|
) |
|
|
for i in range(tuple.__len__(value)) |
|
|
] |
|
|
|
|
|
tuple_vt = TupleVariable( |
|
|
output, source=self.source, mutation_type=ValueMutationExisting() |
|
|
) |
|
|
result = UserDefinedTupleVariable( |
|
|
value, tuple_vt=tuple_vt, source=self.source |
|
|
) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif isinstance(value, list): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
|
|
|
|
|
|
output = [ |
|
|
LazyVariableTracker.create( |
|
|
list.__getitem__(value, i), |
|
|
source=ListGetItemSource(self.get_source(), i), |
|
|
) |
|
|
for i in range(list.__len__(value)) |
|
|
] |
|
|
list_vt = ListVariable( |
|
|
output, source=self.source, mutation_type=ValueMutationExisting() |
|
|
) |
|
|
result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif isinstance(value, (set, frozenset)): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
L = list(dict.fromkeys(value)) |
|
|
output = [ |
|
|
LazyVariableTracker.create( |
|
|
list.__getitem__(L, i), |
|
|
source=NonSerializableSetGetItemSource(self.get_source(), i), |
|
|
) |
|
|
for i in range(list.__len__(L)) |
|
|
] |
|
|
set_vt_cls = SetVariable if isinstance(value, set) else FrozensetVariable |
|
|
set_vt = set_vt_cls( |
|
|
output, source=self.source, mutation_type=ValueMutationExisting() |
|
|
) |
|
|
result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif issubclass(type(value), MutableMapping): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
result = MutableMappingVariable(value, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif is_frozen_dataclass(value): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
result = FrozenDataClassVariable.create(self.tx, value, source=self.source) |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif isinstance(value, dict_keys): |
|
|
if all(ConstantVariable.is_literal(k) for k in value): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
items = [SourcelessBuilder.create(self.tx, v) for v in value] |
|
|
install_guard( |
|
|
self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH), |
|
|
self.get_source().make_guard(GuardBuilder.EQUALS_MATCH), |
|
|
) |
|
|
return DictKeySetVariable(items, source=self.source) |
|
|
else: |
|
|
unimplemented_v2( |
|
|
gb_type="non-const keys in dict_keys", |
|
|
context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", |
|
|
explanation="Dynamo expects dict_keys keys to be constants.", |
|
|
hints=[ |
|
|
"Ensure your dict_keys keys are constants (e.g. int, float, strings)", |
|
|
], |
|
|
) |
|
|
elif IntWrapperVariable.is_matching_object(value): |
|
|
from torch.export.dynamic_shapes import _DimHintType |
|
|
|
|
|
if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC: |
|
|
return self.wrap_symint(value.val) |
|
|
elif value.dynamism.type == _DimHintType.DYNAMIC: |
|
|
log.debug( |
|
|
"%s marked %s via IntWrapper", |
|
|
self.source.name(), |
|
|
DimDynamic.DYNAMIC, |
|
|
) |
|
|
return self.wrap_symint( |
|
|
value.val, |
|
|
dynamism=DimDynamic.DYNAMIC, |
|
|
context=SymIntSymbolicContext( |
|
|
constraint=RelaxedUnspecConstraint(warn_only=False) |
|
|
), |
|
|
) |
|
|
elif value.dynamism.type == _DimHintType.AUTO: |
|
|
log.debug( |
|
|
"%s marked %s via IntWrapper", |
|
|
self.source.name(), |
|
|
DimDynamic.DYNAMIC, |
|
|
) |
|
|
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC) |
|
|
else: |
|
|
raise RuntimeError(f"Undefined dynamism {value.dynamism}") |
|
|
else: |
|
|
return self.wrap_user_defined(value) |
|
|
|
|
|
def wrap_user_defined(self, value: Any):
    """Wrap an arbitrary object as a ``UserDefinedObjectVariable``.

    A TYPE_MATCH guard is installed first. The new variable is registered
    with the side-effects tracker only when
    ``SideEffects.cls_supports_mutation_side_effects`` accepts the object's
    class; otherwise the untracked variable is returned as-is.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    user_obj_vt = UserDefinedObjectVariable(value, source=self.source)

    if SideEffects.cls_supports_mutation_side_effects(type(value)):
        return self.tx.output.side_effects.track_object_existing(
            value, user_obj_vt
        )
    return user_obj_vt
|
|
|
|
|
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): |
# Wrap a list-like input into a list-style VariableTracker whose elements
# are lazily wrapped with per-index GetItemSource sources.
# Returns a ConstantVariable, a TupleVariable, or the class chosen by
# BaseListVariable.cls_for_instance(value), depending on the branch taken.
# NOTE(review): indentation and the original comments are missing in this
# view; nesting remarks below are inferred from statement order and should
# be confirmed against the formatted source.
|
|
# Graph-break when an element is the container itself.
for item in value: |
|
|
if item is value: |
|
|
unimplemented_v2( |
|
|
gb_type="list elements are pointing to the list itself", |
|
|
context="", |
|
|
explanation="Dynamo does not support lists whose items reference to itself", |
|
|
hints=["Avoid using self referential list"], |
|
|
) |
|
|
|
|
|
# With config.specialize_int set, an exact torch.Size becomes a constant
# protected by a CONSTANT_MATCH guard.
if config.specialize_int and type(value) is torch.Size: |
|
|
self.install_guards(GuardBuilder.CONSTANT_MATCH) |
|
|
return ConstantVariable.create(value=value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Guard on the container length.
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# An exact tuple made only of literals, coming from an unspecialized
# nn.Module source, is constant-guarded and returned as a TupleVariable of
# ConstantVariables (built without a source argument).
if ( |
|
|
istype(value, tuple) |
|
|
and all(ConstantVariable.is_literal(item) for item in value) |
|
|
and self.source.guard_source().is_unspecialized_nn_module() |
|
|
): |
|
|
self.install_guards(GuardBuilder.CONSTANT_MATCH) |
|
|
return TupleVariable([ConstantVariable.create(item) for item in value]) |
|
|
|
|
|
# One LazyVariableTracker per element, each sourced from index i of this
# builder's source.
output = [ |
|
|
LazyVariableTracker.create( |
|
|
item, |
|
|
source=GetItemSource(self.get_source(), i), |
|
|
) |
|
|
for i, item in enumerate(value) |
|
|
] |
|
|
|
|
|
# "self" from the traced frame's local scope (None when absent, because
# .get is used); it is handed to get_locals_to_steal below.
maybe_gm = self.tx.output.local_scope.get("self") |
|
|
# Special path: this input is a local whose name is listed by
# get_locals_to_steal(maybe_gm).
if isinstance( |
|
|
self.source, LocalSource |
|
|
) and self.source.local_name in get_locals_to_steal(maybe_gm): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
source = self.source |
|
|
assert isinstance(value, list) |
|
|
# The whole list is lifted as a single graph input and flagged with
# "steal_arg" in the node metadata.
tensor_list_proxy = self.tx.output.root_tracer.create_graph_input( |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
type(value), |
|
|
value, |
|
|
source=source, |
|
|
) |
|
|
tensor_list_proxy.node.meta["steal_arg"] = True |
|
|
|
|
|
|
# The result exposes `.items`, iterated below as one variable per element.
list_variable = wrap_fx_proxy_cls( |
|
|
target_cls=TensorVariable, |
|
|
tx=self.tx, |
|
|
proxy=tensor_list_proxy, |
|
|
example_value=value, |
|
|
subclass_type=None, |
|
|
source=source, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
# For each unpacked element: record it in input_source_to_var under its
# GetItemSource, attach its tensor_dict metadata, and collect a
# TENSOR_MATCH guard that holds the real tensor through a TensorWeakRef.
guards = [] |
|
|
for i, tensor_variable in enumerate(list_variable.items): |
|
|
source_i = GetItemSource(base=source, index=i, index_is_slice=False) |
|
|
|
|
|
|
self.tx.output.input_source_to_var[source_i] = tensor_variable |
|
|
tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict( |
|
|
value[i] |
|
|
) |
|
|
guard = functools.partial( |
|
|
GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i]) |
|
|
) |
|
|
guards.append(source_i.make_guard(guard)) |
|
|
|
|
|
# All per-element guards are installed in one call.
install_guard(*guards, skip=1) |
|
|
|
|
|
# Graph-arg record for the lifted list itself (not a tensor, no fake tensor).
grapharg = GraphArg( |
|
|
source, |
|
|
value, |
|
|
pass_arg_as_tensor=False, |
|
|
fake_tensor=None, |
|
|
is_tensor=False, |
|
|
) |
|
|
tensor_list_proxy.node.meta["grapharg"] = grapharg |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Force every lazy element tracker created above to realize.
# NOTE(review): presumably this belongs to the stolen-list path only; the
# nesting cannot be recovered from this view — confirm.
for vt in output: |
|
|
vt.realize() |
|
|
|
|
|
# The variable class is chosen from the instance type; only exact
# list/deque values are registered as mutable with side_effects.
result = BaseListVariable.cls_for_instance(value)(output, source=self.source) |
|
|
if istype(value, (list, collections.deque)): |
|
|
return self.tx.output.side_effects.track_mutable(value, result) |
|
|
return result |
|
|
|
|
|
def wrap_tuple_iterator(self, value: tuple_iterator):
    """Wrap a tuple iterator as a ``TupleIteratorVariable``.

    Installs a TUPLE_ITERATOR_LEN guard, then wraps the element at every
    index in ``range(tuple_iterator_len(value))`` with its own
    ``TupleIteratorGetItemSource``. The resulting variable is registered as
    a mutable object with the side-effects tracker.
    """
    self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)

    wrapped_items = []
    for idx in range(tuple_iterator_len(value)):
        item_builder = VariableBuilder(
            self.tx, TupleIteratorGetItemSource(self.get_source(), idx)
        )
        wrapped_items.append(item_builder(tuple_iterator_getitem(value, idx)))

    iterator_vt = TupleIteratorVariable(wrapped_items, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
|
|
|
|
|
def wrap_range_iterator(self, value: range_iterator):
    """Wrap a range iterator as a ``ListIteratorVariable`` of constants.

    Installs a RANGE_ITERATOR_MATCH guard, materializes the values from a
    deep copy of the iterator (the original is never iterated here), and
    registers the result as a mutable object with the side-effects tracker.
    """
    self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)

    snapshot = copy.deepcopy(value)
    constants = []
    for element in snapshot:
        constants.append(ConstantVariable.create(element))

    iterator_vt = ListIteratorVariable(constants, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
|
|
|
|
|
def wrap_slice_range(self, value: Union[slice, range]):
    """Wrap a ``slice`` or ``range`` by wrapping its three components.

    ``start``, ``stop`` and ``step`` are each built through a
    ``VariableBuilder`` with an ``AttrSource`` on this builder's source.
    A TYPE_MATCH guard is installed afterwards, and the components are
    handed to ``SliceVariable`` for slices or ``RangeVariable`` otherwise.
    """
    components = []
    for attr in ("start", "stop", "step"):
        attr_builder = VariableBuilder(
            self.tx, AttrSource(self.get_source(), attr)
        )
        components.append(attr_builder(getattr(value, attr)))

    self.install_guards(GuardBuilder.TYPE_MATCH)

    variable_cls = SliceVariable if isinstance(value, slice) else RangeVariable
    return variable_cls(components, source=self.source)
|
|
|
|
|
def mark_static_input(self, value: torch.Tensor, guard: bool):
    """Mark ``value`` as a static-address input.

    Logs the marking, calls ``mark_static_address(value, guard=guard)``, and,
    if the tensor is already tracked in ``side_effects``, copies its
    ``_dynamo_static_input_type`` into the ``tensor_dict`` metadata of the
    tracked variable's proxy node.
    """
    from ..decorators import mark_static_address

    # NOTE(review): the trailing ")" in this log message looks like a typo;
    # it is kept unchanged here.
    static_inputs_log.debug(
        "Marking static input %s, id: %s)", self.source.name(), id(value)
    )
    mark_static_address(value, guard=guard)

    if value not in self.tx.output.side_effects:
        return

    tracked_vt = self.tx.output.side_effects[value]
    static_input_type = value._dynamo_static_input_type
    tracked_vt.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = (
        static_input_type
    )
|
|
|
|
|
def wrap_module(self, value: torch.nn.Module): |
# Wrap an nn.Module input. The checks run in this order:
#   1. empty __dict__                      -> graph break (uninitialized module)
#   2. OptimizedModule                     -> delayed graph break if its forward
#                                             is disabled, else recurse on _orig_mod
#   3. RNN/GRU/LSTM and not allow_rnn      -> graph break
#   4. _is_fsdp_managed_module             -> FSDPManagedNNModuleVariable
#   5. mutation_guard.is_dynamic_nn_module -> Unspecialized(Builtin)NNModuleVariable
#   6. DistributedDataParallel subclass    -> UnspecializedNNModuleVariable
#   7. otherwise                           -> register_attr_or_module
# NOTE(review): indentation and the original comments are missing in this
# view; nesting remarks are inferred from statement order — confirm against
# the formatted source.
|
|
from ..eval_frame import OptimizedModule |
|
|
|
|
|
# A module with an empty instance __dict__ is reported as uninitialized.
if len(value.__dict__) == 0: |
|
|
unimplemented_v2( |
|
|
gb_type="Uninitialized nn.Module", |
|
|
context=typestr(value), |
|
|
explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", |
|
|
hints=[ |
|
|
*graph_break_hints.USER_ERROR, |
|
|
"Ensure your nn.Module instance has called `super().__init__()`.", |
|
|
], |
|
|
) |
|
|
if istype(value, OptimizedModule): |
|
|
|
|
|
|
# getattr_static reads the marker without triggering attribute hooks; when
# the forward carries `_torchdynamo_disable`, a DelayGraphBreakVariable
# carrying the stored reason is returned instead of wrapping the module.
if inspect.getattr_static(value.forward, "_torchdynamo_disable", False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
msg = inspect.getattr_static( |
|
|
value.forward, "_torchdynamo_disable_msg", None |
|
|
) |
|
|
return DelayGraphBreakVariable( |
|
|
source=self.source, |
|
|
msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", |
|
|
) |
|
|
|
|
|
|
# Otherwise guard on the wrapper type, retarget this builder's source at
# `_orig_mod`, and wrap the original module recursively.
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
self.source = AttrSource(self.source, "_orig_mod") |
|
|
return self.wrap_module(value._orig_mod) |
|
|
|
|
|
# Recurrent modules are only traced when config.allow_rnn is set.
if ( |
|
|
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) |
|
|
and not config.allow_rnn |
|
|
): |
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap RNN, GRU, or LSTM", |
|
|
context=str(value), |
|
|
explanation="Dynamo does not support RNN, GRU, or LSTM.", |
|
|
hints=[*graph_break_hints.SUPPORTABLE], |
|
|
) |
|
|
|
|
|
# FSDP-managed modules (flagged through `_is_fsdp_managed_module`).
if getattr(value, "_is_fsdp_managed_module", False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# `_fsdp_use_orig_params` must be truthy; otherwise graph-break.
if not getattr(value, "_fsdp_use_orig_params", False): |
|
|
unimplemented_v2( |
|
|
gb_type="FSDP with use_orig_params=False", |
|
|
context="", |
|
|
explanation="Dynamo only supports FSDP with use_orig_params=True", |
|
|
hints=[], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
result = FSDPManagedNNModuleVariable(value, source=self.get_source()) |
|
|
# Only classes accepted by cls_supports_mutation_side_effects are
# registered with the side-effects tracker.
if not SideEffects.cls_supports_mutation_side_effects(type(value)): |
|
|
|
|
|
|
return result |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# An _AttrProxy is replaced by its base module, and the builder's source is
# wrapped in AttrProxySource to match.
if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): |
|
|
value = value.get_base() |
|
|
self.source = AttrProxySource(self.source) |
|
|
|
|
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules: |
|
|
freezing = is_parameter_freezing() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parameters are marked static only when `named_parameters` is still the
# original function (og_module_named_parameters_fn_ptr). A TypeError
# raised while iterating is re-raised via raise_observed_exception.
if ( |
|
|
callable(value.named_parameters) |
|
|
and value.named_parameters.__func__ |
|
|
is og_module_named_parameters_fn_ptr |
|
|
): |
|
|
try: |
|
|
for _, p in value.named_parameters(): |
|
|
self.mark_static_input(p, guard=freezing) |
|
|
except TypeError as e: |
|
|
raise_observed_exception(type(e), self.tx, args=list(e.args)) |
|
|
|
|
|
# Same treatment for buffers.
if ( |
|
|
callable(value.named_buffers) |
|
|
and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr |
|
|
): |
|
|
try: |
|
|
for _, b in value.named_buffers(): |
|
|
self.mark_static_input(b, guard=freezing) |
|
|
except TypeError as e: |
|
|
raise_observed_exception(type(e), self.tx, args=list(e.args)) |
|
|
|
|
|
# When freezing, the module is also recorded in output.nn_modules under
# this builder's name.
if freezing: |
|
|
|
|
|
|
|
|
|
|
|
|
self.tx.output.nn_modules[self.name] = value |
|
|
|
|
|
# Modules defined under torch.nn.modules (excluding the container
# submodule) or torch.ao., or whose class has `_dynamo_marked_static`, use
# the "builtin" variable and source classes; everything else uses the
# plain unspecialized ones.
if ( |
|
|
value.__module__.startswith(("torch.nn.modules", "torch.ao.")) |
|
|
and not value.__module__.startswith("torch.nn.modules.container") |
|
|
) or getattr(value.__class__, "_dynamo_marked_static", False): |
|
|
new_source = self.source |
|
|
# The source is wrapped only when inlining is enabled and either this is
# not an export trace or config.install_free_tensors is set.
if config.inline_inbuilt_nn_modules and ( |
|
|
not self.tx.output.export or config.install_free_tensors |
|
|
): |
|
|
|
|
|
|
new_source = UnspecializedBuiltinNNModuleSource(self.source) |
|
|
result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) |
|
|
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) |
|
|
else: |
|
|
new_source = self.source |
|
|
if config.inline_inbuilt_nn_modules and ( |
|
|
not self.tx.output.export or config.install_free_tensors |
|
|
): |
|
|
|
|
|
|
new_source = UnspecializedNNModuleSource(self.source) |
|
|
result = UnspecializedNNModuleVariable(value, source=new_source) |
|
|
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) |
|
|
|
|
|
# Same side-effects registration rule as in the FSDP branch.
if not SideEffects.cls_supports_mutation_side_effects(type(value)): |
|
|
|
|
|
|
return result |
|
|
return self.tx.output.side_effects.track_object_existing(value, result) |
|
|
elif issubclass( |
|
|
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel |
|
|
): |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
return UnspecializedNNModuleVariable(value, source=self.get_source()) |
|
|
else: |
|
|
# Fallback: register the module on the output graph under this builder's
# name. No guard is installed by this branch itself.
return self.tx.output.register_attr_or_module( |
|
|
value, |
|
|
self.name, |
|
|
source=self.get_source(), |
|
|
|
|
|
|
) |
|
|
|
|
|
def wrap_literal(self, value):
    """Wrap a Python literal as a constant or a symbolic number.

    * Non-``int`` values: a ``float`` goes through ``wrap_symfloat`` when
      ``config.specialize_float`` is off; anything else becomes a
      ``ConstantVariable`` under a CONSTANT_MATCH guard (lists and sets are
      additionally tracked as mutable).
    * ``int`` values: sources whitelisted as dynamic or unbacked are wrapped
      as symints with the matching dynamism. Otherwise, when
      ``config.specialize_int`` is off and ``is_int_specialization_case``
      holds, the int is specialized as a constant under an EQUALS_MATCH
      guard; in every remaining case ``wrap_symint`` decides.
    """
    if type(value) is not int:
        if not config.specialize_float and type(value) is float:
            return self.wrap_symfloat(value)

        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        constant_vt = ConstantVariable.create(value=value, source=self.source)
        if isinstance(value, (list, set)):
            return self.tx.output.side_effects.track_mutable(value, constant_vt)
        return constant_vt

    # The source whitelists take precedence over specialization settings.
    if is_dynamic_source(self.source.name()):
        log.debug("%s marked dynamic via source whitelist", self.source.name())
        return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)

    if is_unbacked_source(self.source.name()):
        log.debug("%s marked unbacked via source whitelist", self.source.name())
        return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)

    if config.specialize_int or not is_int_specialization_case(
        value, self.source
    ):
        return self.wrap_symint(value)

    # Specialization path: the int is baked in as a constant.
    recompile_hint = None
    if (
        self.source.guard_source().is_unspecialized_builtin_nn_module()
        or self.source.guard_source().is_unspecialized_nn_module()
    ):
        # Attached to the guard so a recompile on an nn.Module integer
        # attribute can tell the user how to make it dynamic.
        recompile_hint = (
            "torch.compile considers integer attributes of the nn.Module to be static. "
            "If you are observing recompilation, you might want to make this integer dynamic "
            "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this "
            "integer into a tensor."
        )

    process_automatic_dynamic(
        self.tx,
        self.source.name(),
        FrameStateSizeEntry.make_scalar(value),
        is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
    )
    equals_guard = functools.partial(
        GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint
    )
    self.install_guards(equals_guard)
    return ConstantVariable.create(value=value, source=self.source)
|
|
|
|
|
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
    """Raise if ``value`` is already a fake tensor owned by this trace.

    Args:
        value: the tensor about to be wrapped.

    Raises:
        InternalTorchDynamoError: when ``value`` is fake and its fake mode is
            this translator's ``fake_mode``.
    """
    if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
        # The message must be ONE string: with a comma between the two
        # fragments they become two positional args, and the exception
        # renders as a tuple instead of a sentence.
        raise InternalTorchDynamoError(
            "Cannot wrap a Tensor that has already been "
            "wrapped by this instance of Dynamo"
        )
|
|
|
|
|
def wrap_tensor(self, value: torch.Tensor): |
# Wrap a tensor input. Depending on its source and static-address status
# the tensor is either registered on the output graph via
# register_attr_or_module (early returns below), or lifted to a graph input
# backed by a fake example value, with guards and node metadata installed.
# Returns the resulting variable tracker.
# NOTE(review): indentation and the original comments are missing in this
# view; nesting remarks are inferred from statement order — confirm against
# the formatted source.
|
|
source = self.get_source() |
|
|
|
|
|
|
|
|
|
|
|
|
# The tensor must not already be tracked by side_effects.
assert value not in self.tx.output.side_effects |
|
|
|
|
|
|
is_static_input = get_static_address_type(value) is not None |
|
|
|
|
|
# With inlining enabled, a not-yet-static nn.Parameter (or any tensor whose
# source is an unspecialized nn.Module) is marked static here.
if ( |
|
|
config.inline_inbuilt_nn_modules |
|
|
and not is_static_input |
|
|
and ( |
|
|
isinstance(value, torch.nn.Parameter) |
|
|
|
|
|
|
|
|
|
or (source and source.guard_source().is_unspecialized_nn_module()) |
|
|
) |
|
|
): |
|
|
self.mark_static_input(value, guard=is_parameter_freezing()) |
|
|
is_static_input = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Free-tensor installation applies only with config.install_free_tensors
# and a global, nonlocal, or unspecialized-nn-module source.
should_install_free_tensor = config.install_free_tensors and ( |
|
|
is_from_global_source(source) |
|
|
or is_from_nonlocal_source(source) |
|
|
or is_from_unspecialized_nn_module_source(source) |
|
|
) |
|
|
|
|
|
# Static inputs become graph attributes when inlining is off or a freezing
# mode is active.
make_graph_attribute = is_static_input and ( |
|
|
not config.inline_inbuilt_nn_modules |
|
|
or is_parameter_freezing() |
|
|
or torch._dynamo.config.prepare_freezing |
|
|
) |
|
|
|
|
|
# Early return 1: register on the output graph. FSDP-module sources are
# excluded from the specialized-module / graph-attribute arm.
if should_install_free_tensor or ( |
|
|
(source.guard_source().is_specialized_nn_module() or make_graph_attribute) |
|
|
and not source.guard_source().is_fsdp_module() |
|
|
): |
|
|
self.assert_not_wrapped_by_this_graph(value) |
|
|
return self.tx.output.register_attr_or_module( |
|
|
value, self.name, source=source |
|
|
) |
|
|
|
|
|
# Early return 2: "guarded" static addresses are ID_MATCH-guarded and
# registered on the output graph.
if get_static_address_type(value) == "guarded": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
self.assert_not_wrapped_by_this_graph(value) |
|
|
return self.tx.output.register_attr_or_module( |
|
|
value, self.name, source=source |
|
|
) |
|
|
|
|
|
# Early return 3: constant sources are registered under a sanitized name
# (every run of non-alphanumeric characters replaced by "_").
if is_constant_source(source): |
|
|
self.assert_not_wrapped_by_this_graph(value) |
|
|
return self.tx.output.register_attr_or_module( |
|
|
value, |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
source=source, |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Early return 4: a source that was already lifted reuses its variable.
is_duplicate_tensor = source in self.tx.output.input_source_to_var |
|
|
if is_duplicate_tensor: |
|
|
return self.tx.output.input_source_to_var[source] |
|
|
|
|
|
# NOTE(review): `options` is never populated in this function, so the
# `**options` passed to wrap_fx_proxy below adds no keyword arguments.
options = {} |
|
|
subclass_type = infer_subclass_type(value) |
|
|
if subclass_type is not None: |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
|
|
|
# NOTE(review): an earlier check with this same condition already returns,
# so this guard looks unreachable unless the static-address type changes in
# between — confirm.
if get_static_address_type(value) == "guarded": |
|
|
self.install_guards(GuardBuilder.ID_MATCH) |
|
|
|
|
|
|
|
|
|
self.assert_not_wrapped_by_this_graph(value) |
|
|
|
|
|
# Nested tensors that are not the NestedTensor subclass are rejected.
if ( |
|
|
isinstance(value, torch.Tensor) |
|
|
and value.is_nested |
|
|
and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) |
|
|
): |
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap strided NestedTensor", |
|
|
context="", |
|
|
explanation="torch.compile does not support strided NestedTensor", |
|
|
hints=[], |
|
|
) |
|
|
|
|
|
|
|
|
|
# Sparse tensors are rejected unless this is an export trace with
# config.capture_sparse_compute enabled.
if ( |
|
|
isinstance(value, torch.Tensor) |
|
|
and is_sparse_any(value) |
|
|
and (not self.tx.export or not config.capture_sparse_compute) |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
unimplemented_v2( |
|
|
gb_type="Attempted to wrap sparse Tensor", |
|
|
context="", |
|
|
explanation="torch.compile does not support sparse Tensors", |
|
|
hints=[*graph_break_hints.SUPPORTABLE], |
|
|
) |
|
|
|
|
|
# A tensor whose existing grad has a different dtype is rejected.
if ( |
|
|
safe_has_grad(value) |
|
|
and safe_grad(value) is not None |
|
|
and value.dtype != safe_grad(value).dtype |
|
|
): |
|
|
unimplemented_v2( |
|
|
gb_type="dtype mismatch between tensor and its gradient", |
|
|
context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", |
|
|
explanation="Inconsistent dtype between tensor and its gradient. " |
|
|
"This can happen in FSDP and crashes meta tensor creation.", |
|
|
hints=[*graph_break_hints.SUPPORTABLE], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the example value for this tensor and lift it as a graph input
# under a sanitized name.
example_value = wrap_to_fake_tensor_and_record( |
|
|
value, tx=self.tx, is_tensor=True, source=source |
|
|
) |
|
|
|
|
|
tensor_proxy = self.tx.output.root_tracer.create_graph_input( |
|
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), |
|
|
type(value), |
|
|
example_value, |
|
|
source=source, |
|
|
) |
|
|
cache_real_value_when_export(self.tx, tensor_proxy, value) |
|
|
|
|
|
tensor_variable = wrap_fx_proxy( |
|
|
tx=self.tx, |
|
|
proxy=tensor_proxy, |
|
|
example_value=example_value, |
|
|
subclass_type=subclass_type, |
|
|
source=source, |
|
|
**options, |
|
|
) |
|
|
|
|
|
# For views, the base tensor is also passed through
# wrap_to_fake_tensor_and_record under the `_base` attribute source; the
# return value is discarded.
if value._is_view(): |
|
|
|
|
|
|
|
|
|
|
|
|
wrap_to_fake_tensor_and_record( |
|
|
value._base, |
|
|
tx=self.tx, |
|
|
source=AttrSource(source, "_base"), |
|
|
is_tensor=True, |
|
|
) |
|
|
|
|
|
# Default guard is TENSOR_MATCH; gradients reached through an optimizer
# source use NOT_NONE_MATCH instead.
guard_type = GuardBuilder.TENSOR_MATCH |
|
|
|
|
|
if isinstance(source, GradSource) and is_from_optimizer_source(source): |
|
|
guard_type = GuardBuilder.NOT_NONE_MATCH |
|
|
|
|
|
# The guard receives the tensor itself for NumpyTensorSource, and a
# TensorWeakRef otherwise.
self.install_guards( |
|
|
functools.partial( |
|
|
guard_type, |
|
|
value=( |
|
|
value |
|
|
if isinstance(source, NumpyTensorSource) |
|
|
else TensorWeakRef(value) |
|
|
), |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
# Traceable wrapper subclasses get extra metadata/type/attr-list guards,
# and every inner attribute reported by __tensor_flatten__ is wrapped and
# realized eagerly.
if is_traceable_wrapper_subclass(value): |
|
|
self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) |
|
|
self.install_guards(GuardBuilder.TYPE_MATCH) |
|
|
install_guard( |
|
|
SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) |
|
|
) |
|
|
|
|
|
attrs, _ = value.__tensor_flatten__() |
|
|
for attr in attrs: |
|
|
inner_value = getattr(value, attr) |
|
|
inner_source = AttrSource(self.source, attr) |
|
|
LazyVariableTracker.realize_all( |
|
|
VariableBuilder(self.tx, inner_source)(inner_value) |
|
|
) |
|
|
|
|
|
# Record the variable for duplicate detection and attach tensor metadata to
# the placeholder node (asserted not to be set yet).
self.tx.output.input_source_to_var[source] = tensor_variable |
|
|
assert "tensor_dict" not in tensor_proxy.node.meta |
|
|
tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value) |
|
|
|
|
|
|
|
|
|
# Sanity check: the recorded example value must belong to this trace's
# fake mode.
fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] |
|
|
if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: |
|
|
raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") |
|
|
|
|
|
grapharg = GraphArg(source, value, False, fake_tensor_value) |
|
|
tensor_proxy.node.meta["grapharg"] = grapharg |
|
|
return tensor_variable |
|
|
|
|
|
def wrap_numpy_ndarray(self, value):
    """Wrap a ``numpy.ndarray`` as a ``NumpyNdarrayVariable`` graph input.

    The array is converted to a tensor, registered as a graph input under a
    ``NumpyTensorSource``, and the resulting variable is returned with its
    ``source`` reset to the builder's original source.

    Args:
        value: the ``numpy.ndarray`` to wrap.

    Returns:
        The variable produced by ``wrap_fx_proxy_cls`` for
        ``NumpyNdarrayVariable``.
    """
    assert np is not None
    assert isinstance(value, np.ndarray)

    source = NumpyTensorSource(self.get_source())

    from torch._numpy import _util

    was_readonly = not value.flags.writeable
    if was_readonly:
        try:
            # NOTE(review): the flag is flipped on the caller's array and is
            # not restored afterwards.
            value.flags.writeable = True
        except ValueError:
            # The only tolerated failure is an array backed by an nditer.
            assert isinstance(value.base, np.nditer)

    with torch_function_mode_stack_state_mgr.temp_restore_stack():
        try:
            tensor_value = _util._try_convert_to_tensor(value)
            if was_readonly:
                from torch._prims_common import clone_preserve_strides

                # Read-only inputs are cloned after conversion.
                tensor_value = clone_preserve_strides(tensor_value)
        except NotImplementedError as e:
            unimplemented_v2(
                gb_type="failed to convert numpy.ndarray to Tensor",
                context=str(value),
                explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor",
                hints=[],
                from_exc=e,
            )

    # Build (and immediately realize) a tracker for the converted tensor; the
    # result is discarded, so this call is made for its side effects only.
    LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))

    fake_value = wrap_to_fake_tensor_and_record(
        tensor_value,
        tx=self.tx,
        is_tensor=False,
        source=source,
    )
    input_name = re.sub(r"[^a-zA-Z0-9]+", "_", self.name)
    proxy = self.tx.output.root_tracer.create_graph_input(
        input_name,
        type(tensor_value),
        fake_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, proxy, tensor_value)

    ndarray_var = wrap_fx_proxy_cls(
        target_cls=NumpyNdarrayVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=fake_value,
        source=source,
    )
    self.tx.output.input_source_to_var[source] = ndarray_var

    # Re-read the example value from the node metadata for the GraphArg.
    fake_value = ndarray_var.proxy.node.meta["example_value"]
    proxy.node.meta["grapharg"] = GraphArg(
        source,
        tensor_value,
        pass_arg_as_tensor=True,
        fake_tensor=fake_value,
        is_tensor=True,
        example_strong_ref=tensor_value,
    )

    # The returned variable carries the builder's original (unwrapped) source.
    ndarray_var.source = self.source

    return ndarray_var
|
|
|
|
|
def wrap_symint(
    self,
    value,
    dynamism: Optional[DimDynamic] = None,
    context: Optional[SymIntSymbolicContext] = None,
):
    """Wrap a plain Python ``int`` as either a constant or a symbolic int input.

    Args:
        value: the integer to wrap; must be exactly of type ``int``.
        dynamism: optional explicit ``DimDynamic`` policy. When given it
            overrides the automatic-dynamic / static-by-default decision.
        context: optional symbolic context recorded with the tracked fake.

    Returns:
        A previously created variable for ``self.name`` if one is cached, a
        ``ConstantVariable`` when the value is specialized, or a
        ``SymNodeVariable`` bound to a new graph input otherwise.
    """
    assert type(value) is int

    cached_vars = self.tx.output.unspec_variable_map
    if self.name in cached_vars:
        return cached_vars[self.name]

    def specialize(**create_kwargs):
        # Guard on the exact value and hand back a constant.
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, **create_kwargs)

    shape_env = self.tx.output.shape_env
    if TracingContext.get().force_unspec_int_unbacked_size_like:
        wrapped_value = shape_env.create_unbacked_symint()
        _constrain_range_for_size(wrapped_value)
        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, None)
        )
    elif not is_constant_source(self.get_source()):
        if dynamism is None and torch._dynamo.config.specialize_int:
            return specialize(source=self.source)

        name = self.source.name()

        frame_state_entry = process_automatic_dynamic(
            self.tx,
            name,
            FrameStateSizeEntry.make_scalar(value),
            is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
        )

        normalized_source_name = normalize_source_name(self.source.name())
        base_source = self.source
        if isinstance(base_source, ChainedSource):
            base_source = base_source.get_base()

        if dynamism is not None:
            dynamic_dim = dynamism
        elif (
            config.automatic_dynamic_shapes
            and frame_state_entry.scalar is auto_dynamic
        ):
            set_feature_use("dynamo.automatic_dynamic_shapes", True)
            dynamic_dim = get_automatic_dynamic_shapes_mark_as()
        elif (
            isinstance(base_source, LocalSource)
            and base_source.dynamism is not None
            and dict(base_source.dynamism).get(normalized_source_name, {0: False})[0]
        ) or not config.assume_static_by_default:
            dynamic_dim = DimDynamic.DYNAMIC
        else:
            # Static by default: record that automatic dynamic was available
            # but not used, then specialize (note: no source is attached on
            # this path).
            if frame_state_entry.scalar is auto_dynamic:
                set_feature_use("dynamo.automatic_dynamic_shapes", False)
            return specialize()

        wrapped_value = shape_env.create_unspecified_symint_and_symbol(
            value,
            source=self.source,
            dynamic_dim=dynamic_dim,
        )
        self.tx.output.tracked_fakes.append(
            TrackedFake(wrapped_value, self.source, context)
        )
    else:
        assert is_constant_source(self.get_source())
        return specialize(source=self.source)

    # From here on we are creating a new symbolic graph input.
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    var_source = self.get_source()

    root_tracer = self.tx.output.root_tracer
    proxy = root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        wrapped_value,
        source=self.get_source(),
    )

    sym_expr = wrapped_value.node.expr
    assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol."
    self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy

    unspec_var = SymNodeVariable(proxy, wrapped_value, source=var_source)
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if not is_constant_source(self.get_source()):
        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )

    return unspec_var
|
|
|
|
|
def wrap_symfloat(self, value):
    """Wrap a Python ``float`` as a constant or as a tensor-backed symbolic float.

    When the value is not specialized, it is turned into a 0-d ``float64``
    tensor graph input (under a ``FloatTensorSource``) and the returned
    variable is the result of calling ``.item()`` on that input in the graph.

    Args:
        value: the float to wrap.

    Returns:
        A cached variable for ``self.name`` if present, a ``ConstantVariable``
        when specialized, or the variable wrapping the ``item`` call.

    Raises:
        AssertionError: if a new input would be added during export from a
            non-``LocalSource``, or if the fake tensor's mode does not match
            the translator's fake mode.
    """
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    frame_state_entry = process_automatic_dynamic(
        self.tx,
        self.source.name(),
        FrameStateSizeEntry.make_scalar(value),
        is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
    )

    # Any one of these conditions forces specialization to a constant.
    if (
        torch._dynamo.config.specialize_float
        or is_constant_source(self.get_source())
        or math.isnan(value)
        or math.isinf(value)
        or torch._inductor.config.triton.cudagraphs
        or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
        or (
            config.assume_static_by_default
            and frame_state_entry.scalar is not auto_dynamic
        )
    ):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    wrapped_value = torch.tensor(value, dtype=torch.float64)

    # Grad-tracking (functorch) tensors are specialized as well.
    if torch._C._functorch.is_gradtrackingtensor(wrapped_value):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    # The graph input is the tensor-ified float, so it gets its own source.
    source = FloatTensorSource(self.get_source())
    options = {"source": source, "raw_value": value}

    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=source
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )
    assert isinstance(unspec_var, UnspecializedPythonVariable)
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if self.tx.export and not isinstance(self.get_source(), LocalSource):
        raise AssertionError(
            f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
        )

    fake_tensor_value = unspec_var.proxy.node.meta["example_value"]
    assert is_fake(fake_tensor_value)
    # Both halves of the message are f-strings so the translator's fake mode
    # is actually interpolated (the second half used to be a plain literal).
    assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
        f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
        f"({self.tx.fake_mode}) from InstructionTranslator"
    )

    # The GraphArg is keyed on the original (non-tensor) source but is passed
    # as a tensor at runtime.
    proxy.node.meta["grapharg"] = GraphArg(
        self.get_source(),
        wrapped_value,
        pass_arg_as_tensor=True,
        fake_tensor=fake_tensor_value,
        is_tensor=False,
        example_strong_ref=wrapped_value,
    )

    # Emit `input.item()` directly in the graph to recover the scalar.
    r = wrap_fx_proxy(
        self.tx,
        self.tx.output.create_proxy(
            "call_method",
            "item",
            *proxy_args_kwargs([unspec_var], {}),
        ),
    )
    self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))

    get_metrics_context().set("tensorify_float_attempt", True, overwrite=True)

    return r
|
|
|
|
|
def wrap_unspecialized_primitive(self, value):
    """Wrap a Python primitive as a tensor graph input without specializing it.

    Args:
        value: the primitive to wrap; it is converted with ``torch.tensor``.

    Returns:
        The cached variable for ``self.name`` if present, otherwise the
        variable produced by ``wrap_fx_proxy_cls`` for
        ``UnspecializedPythonVariable``.

    Raises:
        AssertionError: if a new input would be added during export from a
            non-``LocalSource``, or if the fake tensor's mode does not match
            the translator's fake mode.
    """
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    wrapped_value = torch.tensor(value)
    # Random-value sources are the one case that gets no TYPE_MATCH guard.
    if not isinstance(self.get_source(), RandomValueSource):
        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    options = {"source": self.get_source(), "raw_value": value}

    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source()
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=self.get_source(),
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if not is_constant_source(self.get_source()):
        if self.tx.export and not isinstance(self.get_source(), LocalSource):
            raise AssertionError(
                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
            )

        if isinstance(unspec_var, ConstantVariable):
            example_value = unspec_var.value
        else:
            example_value = unspec_var.proxy.node.meta["example_value"]
        assert is_fake(example_value)

        fake_tensor_value = example_value
        # Both halves of the message are f-strings so the translator's fake
        # mode is actually interpolated (the second half used to be a plain
        # literal).
        assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
            f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
            f"({self.tx.fake_mode}) from InstructionTranslator"
        )

        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=True,
            fake_tensor=fake_tensor_value,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )
    return unspec_var
|
|
|
|
|
|
|
|
def _dataclasses_fields_lambda(obj):
    """Return a ``TupleVariable`` holding one variable per dataclass field of ``obj``.

    ``obj`` must be a ``UserDefinedObjectVariable``; anything else is reported
    through ``unimplemented_v2``. When ``obj`` has a source, each field
    variable is given a source that indexes ``__dataclass_fields__`` by the
    field name.
    """
    if not isinstance(obj, UserDefinedObjectVariable):
        unimplemented_v2(
            gb_type="dataclass fields failure",
            context=f"obj: {obj}; variable type: {type(obj)}",
            explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.",
            hints=[],
        )
    value = obj.value

    def field_source(field_name):
        # Sources are only derivable when the parent object has one.
        if not obj.source:
            return None
        fields_dict_source = AttrSource(obj.source, "__dataclass_fields__")
        return DictGetItemSource(fields_dict_source, field_name)

    field_vars = [
        UserDefinedObjectVariable(field, source=field_source(field.name))
        for field in dataclasses.fields(value)
    ]
    return TupleVariable(field_vars)
|
|
|
|
|
|
|
|
def _clone_input(value, fake_mode): |
|
|
if isinstance(value, torch.Tensor): |
|
|
|
|
|
if not ( |
|
|
isinstance(value, FakeTensor) |
|
|
or ( |
|
|
|
|
|
torch._is_functional_tensor(value) |
|
|
and maybe_get_fake_mode(value) is fake_mode |
|
|
) |
|
|
or value.is_nested |
|
|
): |
|
|
|
|
|
value = clone_input(value) |
|
|
|
|
|
return value |
|
|
|
|
|
|
|
|
def wrap_fx_proxy(
    tx, proxy, example_value=None, subclass_type=None, **options
) -> VariableTracker:
    """Wrap ``proxy`` as a tensor variable, choosing the class from ``subclass_type``.

    Without a ``subclass_type`` the result is built with ``TensorVariable``.
    With one, it is built with ``TensorWithTFOverrideVariable`` and
    ``install_global(tx)`` is called on the result before returning it.
    """
    forwarded = dict(
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=subclass_type,
    )
    forwarded.update(options)

    if subclass_type is not None:
        wrapped = wrap_fx_proxy_cls(
            target_cls=TensorWithTFOverrideVariable, **forwarded
        )
        wrapped.install_global(tx)
        return wrapped

    return wrap_fx_proxy_cls(target_cls=TensorVariable, **forwarded)
|
|
|
|
|
|
|
|
def cache_real_value_when_export(tx, proxy, example_value):
    """Store ``example_value`` in the tracer's real-value cache when exporting.

    Does nothing unless ``tx.export`` is truthy. Otherwise the value is passed
    through ``_clone_input`` with torch-function subclass dispatch disabled
    and stored under ``proxy.node``.
    """
    if not tx.export:
        return

    with torch._C.DisableTorchFunctionSubclass():
        cached = _clone_input(example_value, tx.fake_mode)
        proxy.tracer.real_value_cache[proxy.node] = cached
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    """Dispatch proxy wrapping based on what ``example_value`` is.

    * ``None``: the example value is computed by ``_wrap_fx_proxy``.
    * a ``torch.Tensor``: handled by ``_wrap_fx_preexisting_tensor``.
    * anything else: handed straight to ``handle_traced_output``.
    """
    if example_value is None:
        return _wrap_fx_proxy(
            target_cls, tx, proxy, example_value, subclass_type, **options
        )

    if isinstance(example_value, torch.Tensor):
        return _wrap_fx_preexisting_tensor(
            target_cls, tx, proxy, example_value, subclass_type, **options
        )

    # Non-tensor example value supplied by the caller: no fake-value
    # computation is done here, the value is classified as-is.
    return handle_traced_output(
        example_value, tx, proxy, options, subclass_type, target_cls
    )
|
|
|
|
|
|
|
|
|
|
|
def _wrap_fx_preexisting_tensor(
    target_cls, tx, proxy, tensor, subclass_type=None, **options
):
    """Wrap an already-existing tensor for ``proxy`` as a ``target_cls`` variable.

    If ``tensor`` is not already a fake tensor of ``tx.fake_mode`` it is
    fakeified with ``wrap_to_fake_tensor_and_record`` (which requires a
    non-``None`` ``options["source"]``); during export its real value is
    cached first.

    Args:
        target_cls: variable class to construct.
        tx: the instruction translator.
        proxy: fx proxy the tensor is associated with.
        tensor: the pre-existing ``torch.Tensor``.
        subclass_type: forwarded to ``construct_tensor_variable``.
        **options: extra variable options; ``guards`` (if present and not
            ``None``) are merged into ``tx.output.guards``.

    Raises:
        InternalTorchDynamoError: if the resulting tensor is neither on the
            ``meta`` device nor a fake tensor of ``tx.fake_mode``.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tensor, torch.Tensor), (
        f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}"
    )

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    # Placeholders must already carry an example value; other nodes must not.
    if proxy.node.op == "placeholder":
        assert "example_value" in proxy.node.meta, (
            f"placeholder {proxy} doesn't have 'example_value' in node.meta"
        )
    else:
        assert "example_value" not in proxy.node.meta, (
            f"{proxy.node.meta['example_value']}"
        )

    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # A tensor that already belongs to this fake mode needs no conversion.
        if maybe_get_fake_mode(tensor) is not tx.fake_mode:
            # This already clones the real value into the tracer's
            # real_value_cache when tx.export is set; the inline duplicate of
            # that logic (a second, redundant clone) has been dropped.
            cache_real_value_when_export(tx, proxy, tensor)

            is_tensor = target_cls in (TensorVariable, TensorWithTFOverrideVariable)
            assert "source" in options and options["source"] is not None
            tensor = wrap_to_fake_tensor_and_record(
                tensor, tx=tx, is_tensor=is_tensor, source=options["source"]
            )

        if tensor.device.type != "meta" and (
            maybe_get_fake_mode(tensor) is not tx.fake_mode
        ):
            raise InternalTorchDynamoError(
                "`tensor` needs to be a `FakeTensor` "
                f"wrapped by this instance of Dynamo. Found: {tensor}"
            )

    return construct_tensor_variable(
        target_cls, tx, proxy, tensor, subclass_type, options
    )
|
|
|
|
|
|
|
|
|
|
|
def _wrap_fx_proxy(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    """Compute a fake example value for ``proxy.node`` and wrap the result.

    The incoming ``example_value`` argument is not read; it is always replaced
    by the output of ``get_fake_value``. The node must not yet have an
    ``example_value`` in its metadata.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)

    guards = options["guards"] if "guards" in options else None
    if guards is not None:
        tx.output.guards.update(guards)

    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"

    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # Non-graph fakes are permitted for this call only.
        example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)

    return handle_traced_output(
        example_value, tx, proxy, options, subclass_type, target_cls
    )
|
|
|
|
|
|
|
|
|
|
|
def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
    """Build the variable tracker that represents ``example_value`` for ``proxy``.

    The checks below run in a fixed order and the first match wins. Tensors
    become ``target_cls`` variables; tuples/lists are unpacked element by
    element (recursing through ``wrap_fx_proxy_cls``); symbolic numbers,
    streams, events and a fixed set of int/bool/float-returning targets get
    dedicated variables or constants. Anything unmatched is reported through
    ``unimplemented_v2``.
    """
    import torch._functorch.vmap
    import torch._subclasses.fake_tensor
    import torch._utils

    if isinstance(example_value, torch.Tensor):
        var = construct_tensor_variable(
            target_cls, tx, proxy, example_value, subclass_type, options
        )
        # Register the new tensor variable with side-effect tracking.
        tx.output.side_effects._track_obj(
            proxy, var, mutation_type_cls=AttributeMutationNew
        )
        return var

    node = proxy.node
    target = node.target

    # RNG state setters are returned as in-graph function variables.
    if (
        hasattr(target, "__name__")
        and target.__name__ == "set_state"
        and isinstance(target.__self__, torch._C.Generator)
    ) or target == torch.random.set_rng_state:
        return TorchInGraphFunctionVariable(target)

    if target == torch._C._DisableFuncTorch or target == torch.cuda._is_in_bad_fork:
        return UserDefinedObjectVariable(example_value)

    # A torch.Size made purely of ints is a constant size.
    if istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable.create(x) for x in example_value]
        return SizeVariable(sizes, **options)

    if isinstance(example_value, (tuple, list)):
        set_example_value(node, example_value)
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                unpacked.append(
                    ConstantVariable.create(None, **options),
                )
                continue

            proxy_i = proxy.tracer.create_proxy(
                kind="call_function",
                target=operator.getitem,
                args=(proxy, i),
                kwargs={},
            )

            if "source" in options:
                # A sourced sequence is required to be a list here.
                assert isinstance(example_value, list)
                source = options["source"]
                options_i = options.copy()
                options_i["source"] = GetItemSource(
                    base=source, index=i, index_is_slice=False
                )
            else:
                # Share the parent's options object (not a copy).
                options_i = options

            unpacked.append(
                wrap_fx_proxy_cls(
                    target_cls=target_cls,
                    tx=tx,
                    proxy=proxy_i,
                    example_value=val,
                    **options_i,
                )
            )

        if isinstance(example_value, torch.Size):
            # The original proxy is passed along to SizeVariable.
            return SizeVariable(unpacked, proxy, **options)
        if istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        if istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, **options)

        assert (
            example_value.__class__.__module__ == "torch.return_types"
            or hasattr(example_value, "_fields")
        ), (
            f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
        )
        return NamedTupleVariable(unpacked, example_value.__class__, **options)

    if example_value is None or target is torch.manual_seed:
        return ConstantVariable.create(None, **options)

    if isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        tx.output.current_tracer.track_produced_symints(example_value, proxy)
        set_example_value(node, example_value)
        return SymNodeVariable(proxy, example_value, **options)

    # Stream constructors or a registered device's current_stream.
    if (inspect.isclass(target) and issubclass(target, torch.Stream)) or target in [
        device_interface.current_stream
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(node, example_value)
        return StreamVariable(proxy, example_value, example_value.device, **options)

    # Event constructors or a registered device's Event class.
    if (inspect.isclass(target) and issubclass(target, torch.Event)) or target in [
        device_interface.Event
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(node, example_value)
        return EventVariable(proxy, example_value, **options)

    if target == "query" and node.op == "call_method":
        set_example_value(node, example_value)
        return ConstantVariable(example_value, **options)

    if (
        example_value is not None
        and isinstance(example_value, torch.Event)
        and target == "record_event"
        and node.op == "call_method"
    ):
        set_example_value(node, example_value)
        return EventVariable(proxy, example_value, **options)

    # Targets whose int result is treated as a constant.
    if isinstance(example_value, int) and (
        target
        in [
            torch.sym_int,
            getattr,
            operator.getitem,
            torch._utils._element_size,
            torch.seed,
            operator.mod,
            torch._functorch.vmap._validate_and_get_batch_size,
            torch._functorch.predispatch._vmap_increment_nesting,
            torch._functorch.predispatch._vmap_decrement_nesting,
            getattr(torch.distributed, "get_rank", _missing),
            getattr(torch.distributed, "get_world_size", _missing),
            torch._constrain_as_size,
        ]
        or (node.op == "call_method" and target in ["bit_length"])
    ):
        set_example_value(node, example_value)
        return ConstantVariable.create(example_value, **options)

    if isinstance(example_value, torch.backends.cuda.SDPAParams):
        from .sdpa import SDPAParamsVariable

        set_example_value(node, example_value)
        return SDPAParamsVariable(proxy, **options)

    # Targets whose bool result is treated as a constant.
    if isinstance(example_value, bool) and (
        target
        in [
            torch._C._are_functorch_transforms_active,
            torch._C._functorch.is_batchedtensor,
            torch.backends.cuda.is_flash_attention_available,
            torch.backends.cuda.can_use_flash_attention,
            torch.backends.cuda.can_use_efficient_attention,
            "is_integer",
        ]
        + list(supported_const_comparison_op_values.keys())
    ):
        set_example_value(node, example_value)
        return ConstantVariable.create(example_value, **options)

    if isinstance(example_value, (int, float, bool)) and (
        target is call_torchbind
        or target is flat_apply
        or (node.op == "call_method" and target == "item")
    ):
        set_example_value(node, example_value)
        return ConstantVariable.create(example_value, **options)

    if isinstance(example_value, float) or target in ["hex", "__round__"]:
        set_example_value(node, example_value)
        return ConstantVariable.create(example_value, **options)

    unimplemented_v2(
        gb_type="torch.* op returned non-Tensor",
        context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",
        explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output",
        hints=[],
    )
|
|
|
|
|
|
|
|
def infer_subclass_type(value): |
|
|
if type(value) in ( |
|
|
torch.Tensor, |
|
|
torch.nn.Parameter, |
|
|
torch._subclasses.fake_tensor.FakeTensor, |
|
|
torch._subclasses.functional_tensor.FunctionalTensor, |
|
|
) or is_traceable_wrapper_subclass(value): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return None |
|
|
else: |
|
|
return type(value) |
|
|
|
|
|
|
|
|
def get_specialized_props(target_cls, tx, example_value, subclass_type): |
|
|
specialized_props = target_cls.specialize(example_value) |
|
|
|
|
|
if ( |
|
|
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) |
|
|
and example_value.fake_mode is tx.fake_mode |
|
|
): |
|
|
if subclass_type: |
|
|
tensor_type = subclass_type |
|
|
elif isinstance(example_value, torch.nn.Parameter): |
|
|
tensor_type = torch.nn.Parameter |
|
|
elif isinstance(example_value, torch.nn.Buffer): |
|
|
tensor_type = torch.nn.Buffer |
|
|
else: |
|
|
tensor_type = torch.Tensor |
|
|
specialized_props["class_type"] = tensor_type |
|
|
|
|
|
return specialized_props |
|
|
|
|
|
|
|
|
def construct_tensor_variable(
    target_cls, tx, proxy, example_value, subclass_type, options
):
    """
    Build the final ``target_cls`` variable for ``proxy`` once the example
    value has been prepared by the caller.

    The example value is passed through ``_clone_input`` and stored on the
    node. Note that ``options`` is updated in place with the specialized
    properties before being forwarded to ``target_cls``.
    """
    node = proxy.node

    example_value = _clone_input(example_value, tx.fake_mode)
    set_example_value(node, example_value)

    # Symints produced by the node are only tracked for non-placeholder nodes.
    is_placeholder = node.op == "placeholder"
    if not is_placeholder:
        tx.output.current_tracer.track_produced_symints(example_value, proxy)

    specialized = get_specialized_props(target_cls, tx, example_value, subclass_type)
    options.update(specialized)
    return target_cls(proxy, **options)
|
|
|
|
|
|
|
|
def get_automatic_dynamic_shapes_mark_as():
    """Translate ``config.automatic_dynamic_shapes_mark_as`` into a ``DimDynamic``.

    Raises:
        ValueError: if the config value is not one of ``"dynamic"``,
            ``"unbacked"`` or ``"oblivious"``.
    """
    mark_as = config.automatic_dynamic_shapes_mark_as

    if mark_as == "dynamic":
        return DimDynamic.DYNAMIC
    if mark_as == "unbacked":
        return DimDynamic.SIZE_LIKE_UNBACKED
    if mark_as == "oblivious":
        return DimDynamic.OBLIVIOUS_SIZE

    raise ValueError(
        f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}"
    )
|
|
|
|
|
|
|
|
_DYNAMIC_SOURCES: Optional[set[str]] = None |
|
|
_DYNAMIC_SOURCES_CONFIG_HASH: Optional[int] = None |
|
|
|
|
|
|
|
|
def get_dynamic_sources() -> set[str]: |
|
|
global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH |
|
|
|
|
|
current_hash = hash(torch.compiler.config.dynamic_sources) |
|
|
|
|
|
|
|
|
if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash: |
|
|
return _DYNAMIC_SOURCES |
|
|
|
|
|
|
|
|
_DYNAMIC_SOURCES = { |
|
|
s |
|
|
for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",") |
|
|
if s |
|
|
} |
|
|
_DYNAMIC_SOURCES_CONFIG_HASH = current_hash |
|
|
|
|
|
return _DYNAMIC_SOURCES |
|
|
|
|
|
|
|
|
def is_dynamic_source(source_name: str) -> bool:
    """Return True if ``source_name`` matches an entry of the dynamic-source allowlist.

    An entry matches when it equals ``source_name`` exactly or when
    ``re.match(entry, source_name)`` succeeds. The first match is logged at
    debug level.
    """
    for pattern in get_dynamic_sources():
        # Exact equality is tried first, so the regex is only consulted for
        # entries that differ from the name.
        if pattern != source_name and not re.match(pattern, source_name):
            continue
        log.debug(
            "%s was marked dynamic due to dynamic source allowlist pattern: %s",
            source_name,
            pattern,
        )
        return True
    return False
|
|
|
|
|
|
|
|
def record_automatic_dynamic(
    tx: "InstructionTranslator", name: str, e: torch.Tensor
) -> FrameStateSizeEntry:
    """Record the sizes and (inferred) strides of ``e`` under ``name``.

    For non-sparse tensors each stride is stored either as its concrete value
    or, when it equals ``stride * size`` of a dimension processed earlier in
    sorted order, as the ``InferStride`` for that dimension. Sparse tensors
    record an empty stride tuple.

    Returns:
        The result of ``process_automatic_dynamic`` for the built entry.
    """
    sizes = e.size()

    if is_sparse_any(e):
        inferred_strides = []
    else:
        strides = e.stride()
        ndim = e.dim()

        inferred_strides = [None] * ndim
        # Pair each stride with its negated dimension index before sorting.
        ordered = sorted(
            ((strides[d], -d) for d in range(ndim)),
            key=_nested_int_aware_sort,
        )
        # Maps an expected stride value to the dimension that implies it.
        implied_by = {}
        for dim_stride, negated_dim in ordered:
            d = -negated_dim
            inferred_strides[d] = implied_by.get(dim_stride, dim_stride)
            implied_by.setdefault(dim_stride * sizes[d], InferStride(d))

    entry = FrameStateSizeEntry.make_tensor(tuple(sizes), tuple(inferred_strides))
    return process_automatic_dynamic(tx, name, entry)
|
|
|
|
|
|
|
|
_UNBACKED_SOURCES: Optional[set[str]] = None |
|
|
_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None |
|
|
|
|
|
|
|
|
def get_unbacked_sources() -> set[str]: |
|
|
global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH |
|
|
|
|
|
current_hash = hash(torch.compiler.config.unbacked_sources) |
|
|
|
|
|
|
|
|
if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash: |
|
|
return _UNBACKED_SOURCES |
|
|
|
|
|
|
|
|
_UNBACKED_SOURCES = { |
|
|
s |
|
|
for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",") |
|
|
if s |
|
|
} |
|
|
_UNBACKED_SOURCES_CONFIG_HASH = current_hash |
|
|
|
|
|
return _UNBACKED_SOURCES |
|
|
|
|
|
|
|
|
def is_unbacked_source(source_name: str) -> bool:
    """Return True if ``source_name`` matches an entry of the unbacked-source allowlist.

    An entry matches when it equals ``source_name`` exactly or when
    ``re.match(entry, source_name)`` succeeds. The first match is logged at
    debug level.
    """
    for pattern in get_unbacked_sources():
        # Exact equality is tried first, so the regex is only consulted for
        # entries that differ from the name.
        if pattern != source_name and not re.match(pattern, source_name):
            continue
        log.debug(
            "%s was marked unbacked due to unbacked source allowlist pattern: %s",
            source_name,
            pattern,
        )
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _automatic_dynamic( |
|
|
e, tx, source, static_shapes, outer_only=False |
|
|
) -> SymbolicContext: |
|
|
|
|
|
if e.is_nested and not isinstance( |
|
|
e, torch.nested._internal.nested_tensor.NestedTensor |
|
|
): |
|
|
unimplemented_v2( |
|
|
gb_type="Encountered strided NestedTensor in automatic dynamic dim determination", |
|
|
context="", |
|
|
explanation="torch.compile does not support strided NestedTensor", |
|
|
hints=[], |
|
|
) |
|
|
|
|
|
name = source.name() |
|
|
prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) |
|
|
shape_env_to_source_to_symbol_cache = ( |
|
|
prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None |
|
|
) |
|
|
|
|
|
|
|
|
view_base_context: Optional[SymbolicContext] = None |
|
|
if e._is_view(): |
|
|
base_source = AttrSource(source, "_base") |
|
|
view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) |
|
|
|
|
|
if is_traceable_wrapper_subclass(e) and not outer_only: |
|
|
|
|
|
outer_context = _automatic_dynamic( |
|
|
e, tx, source, static_shapes, outer_only=True |
|
|
) |
|
|
|
|
|
|
|
|
inner_contexts = {} |
|
|
attrs, _ = type(e).__tensor_flatten__(e) |
|
|
for attr in attrs: |
|
|
inner_tensor = getattr(e, attr) |
|
|
inner_source = AttrSource(source, attr) |
|
|
inner_contexts[attr] = _automatic_dynamic( |
|
|
inner_tensor, tx, inner_source, static_shapes |
|
|
) |
|
|
|
|
|
return SubclassSymbolicContext( |
|
|
dynamic_sizes=outer_context.dynamic_sizes, |
|
|
dynamic_strides=outer_context.dynamic_strides, |
|
|
constraint_sizes=outer_context.constraint_sizes, |
|
|
constraint_strides=outer_context.constraint_strides, |
|
|
view_base_context=view_base_context, |
|
|
tensor_source=outer_context.tensor_source, |
|
|
shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, |
|
|
inner_contexts=inner_contexts, |
|
|
) |
|
|
|
|
|
if static_shapes and not is_dynamic_source(name): |
|
|
return StatefulSymbolicContext( |
|
|
dynamic_sizes=[DimDynamic.STATIC] * e.dim(), |
|
|
dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), |
|
|
constraint_sizes=[None] * e.dim(), |
|
|
constraint_strides=[None] * e.dim(), |
|
|
view_base_context=view_base_context, |
|
|
tensor_source=source, |
|
|
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
from torch.fx.experimental.symbolic_shapes import is_nested_int |
|
|
|
|
|
if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): |
|
|
return StatefulSymbolicContext( |
|
|
dynamic_sizes=[ |
|
|
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC |
|
|
for s in e.size() |
|
|
], |
|
|
dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), |
|
|
constraint_sizes=[None] * e.dim(), |
|
|
constraint_strides=[None] * e.dim(), |
|
|
view_base_context=view_base_context, |
|
|
tensor_source=source, |
|
|
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, |
|
|
) |
|
|
|
|
|
|
|
|
frame_state_entry = record_automatic_dynamic(tx, name, e) |
|
|
|
|
|
|
|
|
|
|
|
t_id = id(e) |
|
|
dim2constraint = {} |
|
|
|
|
|
def update_dim2constraint(dim, constraint_range, name): |
|
|
if dim in dim2constraint: |
|
|
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
|
|
|
|
|
old_constraint_range, old_name = dim2constraint[dim] |
|
|
new_constraint_range = StrictMinMaxConstraint( |
|
|
vr=constraint_range.vr & old_constraint_range.vr, |
|
|
warn_only=False, |
|
|
) |
|
|
|
|
|
|
|
|
new_name = old_name or name |
|
|
dim2constraint[dim] = new_constraint_range, new_name |
|
|
else: |
|
|
dim2constraint[dim] = constraint_range, name |
|
|
|
|
|
from torch.export.dynamic_shapes import _RelaxedConstraint |
|
|
|
|
|
if tx.output.export_constraints: |
|
|
for constraint in tx.output.export_constraints: |
|
|
if isinstance(constraint, _RelaxedConstraint): |
|
|
continue |
|
|
if constraint.t_id == t_id: |
|
|
update_dim2constraint( |
|
|
constraint.dim, constraint.constraint_range, constraint.name |
|
|
) |
|
|
|
|
|
dynamic_sizes = [] |
|
|
dynamic_strides = [] |
|
|
constraint_sizes = [] |
|
|
constraint_strides = [] |
|
|
specialize_on = [] |
|
|
for i in range(e.dim()): |
|
|
|
|
|
marked_strict_unbacked = i in getattr( |
|
|
e, "_dynamo_strict_unbacked_indices", set() |
|
|
) |
|
|
marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) |
|
|
marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) |
|
|
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) |
|
|
marked_static = i in getattr(e, "_dynamo_static_indices", set()) |
|
|
|
|
|
specialize_on.append(getattr(e, "_specialize_on", {}).get(i, [])) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
normalized_source_name = normalize_source_name(source.name()) |
|
|
base_source = source |
|
|
if isinstance(base_source, ChainedSource): |
|
|
base_source = base_source.get_base() |
|
|
|
|
|
if marked_dynamic or ( |
|
|
isinstance(base_source, LocalSource) |
|
|
and base_source.dynamism is not None |
|
|
and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i] |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log.debug("automatic dynamic %s marked dynamic", name) |
|
|
mark_size = [auto_unset] * e.dim() |
|
|
mark_size[i] = auto_dynamic |
|
|
frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size) |
|
|
|
|
|
|
|
|
automatic_dynamic_size = ( |
|
|
config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
automatic_dynamic_stride = ( |
|
|
config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i) |
|
|
) |
|
|
|
|
|
if is_dynamic_source(name): |
|
|
log.debug("%s marked dynamic via source whitelist", name) |
|
|
automatic_dynamic_size = True |
|
|
|
|
|
if is_unbacked_source(name): |
|
|
log.debug("%s marked unbacked via source whitelist", name) |
|
|
automatic_dynamic_size = True |
|
|
|
|
|
automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constraint = dim2constraint.get(i) |
|
|
if constraint is None: |
|
|
constraint_size = None |
|
|
constraint_stride = None |
|
|
if marked_dynamic and not config.allow_ignore_mark_dynamic: |
|
|
|
|
|
constraint_stride = None |
|
|
if hasattr(e, "_dynamo_dynamic_range"): |
|
|
dim_range = [ |
|
|
dr for dr in e._dynamo_dynamic_range if dr.dim == i |
|
|
].pop() |
|
|
if dim_range.min is None and dim_range.max is None: |
|
|
constraint_size = RelaxedUnspecConstraint(warn_only=False) |
|
|
else: |
|
|
from torch.fx.experimental.symbolic_shapes import ( |
|
|
StrictMinMaxConstraint, |
|
|
) |
|
|
|
|
|
constraint_size = StrictMinMaxConstraint( |
|
|
vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), |
|
|
warn_only=False, |
|
|
) |
|
|
else: |
|
|
constraint_size = RelaxedUnspecConstraint(warn_only=False) |
|
|
elif marked_strict_unbacked: |
|
|
constraint_size = RelaxedUnspecConstraint(warn_only=False) |
|
|
elif not marked_static and automatic_dynamic: |
|
|
set_feature_use("dynamo.automatic_dynamic_shapes", True) |
|
|
if automatic_dynamic_size: |
|
|
constraint_size = RelaxedUnspecConstraint(warn_only=True) |
|
|
if automatic_dynamic_stride: |
|
|
constraint_stride = RelaxedUnspecConstraint(warn_only=True) |
|
|
else: |
|
|
if not marked_static and not config.automatic_dynamic_shapes: |
|
|
set_feature_use("dynamo.automatic_dynamic_shapes", False) |
|
|
constraint_size = None |
|
|
constraint_stride = None |
|
|
else: |
|
|
constraint_size, name_ = constraint |
|
|
constraint_stride = None |
|
|
dim_name = f"{name}.size()[{i}]" |
|
|
tx.output.shape_env.source_name_to_debug_name[dim_name] = name_ |
|
|
constraint_sizes.append(constraint_size) |
|
|
constraint_strides.append(constraint_stride) |
|
|
|
|
|
if marked_unbacked or is_unbacked_source(name): |
|
|
dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED |
|
|
elif ( |
|
|
constraint_size is not None |
|
|
or marked_dynamic |
|
|
or marked_weak_dynamic |
|
|
or is_nested_int(e.size()[i]) |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
if automatic_dynamic: |
|
|
dynamic_size = get_automatic_dynamic_shapes_mark_as() |
|
|
else: |
|
|
dynamic_size = DimDynamic.DYNAMIC |
|
|
elif static_shapes or config.assume_static_by_default or marked_static: |
|
|
dynamic_size = DimDynamic.STATIC |
|
|
else: |
|
|
|
|
|
dynamic_size = DimDynamic.DUCK |
|
|
|
|
|
if constraint_stride is not None: |
|
|
dynamic_stride = DimDynamic.DYNAMIC |
|
|
else: |
|
|
dynamic_stride = DimDynamic.INFER_STRIDE |
|
|
|
|
|
dynamic_sizes.append(dynamic_size) |
|
|
dynamic_strides.append(dynamic_stride) |
|
|
|
|
|
return StatefulSymbolicContext( |
|
|
dynamic_sizes=dynamic_sizes, |
|
|
dynamic_strides=dynamic_strides, |
|
|
constraint_sizes=constraint_sizes, |
|
|
constraint_strides=constraint_strides, |
|
|
specialize_on=specialize_on, |
|
|
view_base_context=view_base_context, |
|
|
tensor_source=source, |
|
|
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def wrap_to_fake_tensor_and_record( |
|
|
e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None |
|
|
): |
|
|
if ( |
|
|
type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) |
|
|
or isinstance(e, torch.Tensor) |
|
|
or is_traceable_wrapper_subclass(e) |
|
|
): |
|
|
assert source is not None |
|
|
static_shapes, _reason = tensor_always_has_static_shape( |
|
|
e, |
|
|
is_tensor, |
|
|
tensor_source=source, |
|
|
) |
|
|
|
|
|
if not parent_context: |
|
|
symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(source, AttrSource) |
|
|
inner_context_name = source.member |
|
|
symbolic_context = parent_context.inner_contexts[inner_context_name] |
|
|
|
|
|
log.debug( |
|
|
"wrap_to_fake %s %s %s %s", |
|
|
source.name(), |
|
|
tuple(e.shape), |
|
|
symbolic_context, |
|
|
type(e), |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with enable_python_dispatcher(): |
|
|
fake_e = wrap_fake_exception( |
|
|
lambda: tx.fake_mode.from_tensor( |
|
|
e, |
|
|
source=source, |
|
|
symbolic_context=symbolic_context, |
|
|
) |
|
|
) |
|
|
if ( |
|
|
source is not None |
|
|
and isinstance(fake_e, FakeTensor) |
|
|
and (sym_val := fake_e.item_memo) is not None |
|
|
): |
|
|
tx.output.tracked_fakes.append( |
|
|
TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context) |
|
|
) |
|
|
|
|
|
if is_traceable_wrapper_subclass(fake_e): |
|
|
attrs, _ = fake_e.__tensor_flatten__() |
|
|
for attr in attrs: |
|
|
fake_inner = getattr(fake_e, attr) |
|
|
inner = getattr(e, attr) |
|
|
inner_source = AttrSource(source, attr) |
|
|
wrap_to_fake_tensor_and_record( |
|
|
inner, |
|
|
tx, |
|
|
source=inner_source, |
|
|
is_tensor=isinstance(fake_inner, torch.Tensor), |
|
|
parent_context=symbolic_context, |
|
|
) |
|
|
|
|
|
tx.output.tracing_context.tensor_to_context[e] = symbolic_context |
|
|
if is_sparse_any(fake_e): |
|
|
|
|
|
|
|
|
values = fake_e._values() if fake_e.is_sparse else fake_e.values() |
|
|
tx.output.input_source_to_sizes_strides[source] = { |
|
|
"size": fake_e.size(), |
|
|
|
|
|
|
|
|
"stride": (1,) * fake_e.ndim, |
|
|
"values_size": values.size(), |
|
|
"values_stride": values.stride(), |
|
|
} |
|
|
else: |
|
|
tx.output.input_source_to_sizes_strides[source] = { |
|
|
"size": fake_e.size(), |
|
|
"stride": fake_e.stride(), |
|
|
} |
|
|
|
|
|
if ( |
|
|
is_tensor |
|
|
and not (static_shapes and source.is_specialized_nn_module()) |
|
|
and not is_constant_source(source) |
|
|
): |
|
|
tx.output.tracked_fakes.append( |
|
|
TrackedFake(fake_e, source, symbolic_context) |
|
|
) |
|
|
tx.output.tracked_fakes_id_to_source[id(e)].append(source) |
|
|
|
|
|
return fake_e |
|
|
else: |
|
|
return e |
|
|
|
|
|
|
|
|
class SourcelessBuilder:
    """
    Source-free counterpart of the regular builder: maps a Python value to a
    VariableTracker without needing a source. It is meant for simple
    type -> VariableTracker conversions and for objects that are created and
    dropped while inlining (for example a list of tensors built locally and
    then iterated over). Such a list should not appear as an input artifact,
    in reconstruction, or in the graph, yet it may still be convenient to
    model it internally as a ListVariable.

    NOTE - Objects produced here are born UNGUARDED, since there is no source
    to guard on.

    NOTE - The class exists to replace the large if/else
    type -> VariableTracker trees that were spreading through dynamo; expect
    some rough edges.

    The class is never instantiated; call ``SourcelessBuilder.create``.
    """

    def __init__(self) -> None:
        raise AssertionError("Use SourcelessBuilder.create()")

    @staticmethod
    def create(tx: "InstructionTranslator", value) -> VariableTracker:
        """Wrap ``value`` in a VariableTracker without a source.

        The exact type of ``value`` is first looked up in ``_type_handlers``.
        On a miss the checks below run in order, and the first match decides
        the result. A value that matches nothing is reported through
        ``unimplemented_v2``.
        """
        value_type = type(value)
        handler = SourcelessBuilder._type_handlers.get(value_type)
        if handler:
            return handler(tx, value)

        if isinstance(value, VariableTracker):
            # Already wrapped; hand it back untouched.
            return value
        if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
            return UserDefinedObjectVariable(value)
        if ConstantVariable.is_literal(value):
            return ConstantVariable.create(value)
        if callable(value) and trace_rules.lookup_callable(value) is not None:
            if trace_rules.is_callable_allowed(value):
                tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value)(value)
        if callable(value) and UserDefinedClassVariable.is_supported_new_method(
            value
        ):
            # ``value`` is a supported ``__new__``; wrap its owner and expose
            # ``__new__`` as an attribute of that wrapper.
            owner = value.__self__
            owner_vt = trace_rules.lookup_callable(owner)(owner)
            return GetAttrVariable(owner_vt, "__new__")
        if is_function_or_wrapper(value):
            return trace_rules.lookup(value)(value)
        if isinstance(
            value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
        ):
            return EnumVariable(value)
        if isinstance(value, (type, abc.ABCMeta)):
            return UserDefinedClassVariable(value)
        if isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)

        # From here on the checks stay in one if/elif chain: the first branch
        # can fall through (NotImplementedError is swallowed) and must then go
        # straight to the graph break without trying the remaining checks.
        if isinstance(value, types.MethodType) and isinstance(
            value.__self__, (type, abc.ABCMeta)
        ):
            # Method bound to a class object: wrap the class, then resolve the
            # method through the wrapper's attribute lookup.
            bound_cls = value.__self__
            method_name = value.__func__.__name__
            assert getattr(bound_cls, method_name) == value
            cls_vt = SourcelessBuilder.create(tx, bound_cls)
            try:
                return cls_vt.var_getattr(tx, method_name)
            except NotImplementedError:
                pass
        elif isinstance(value, torch.fx.graph_module.GraphModule):
            return SourcelessGraphModuleVariable(value)
        elif isinstance(
            value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
        ):
            return UserDefinedObjectVariable(value)
        elif PlacementVariable.is_placement(value):
            return PlacementVariable(value)
        elif DeviceMeshVariable.is_device_mesh(value):
            return DeviceMeshVariable(value)
        elif value is functools.wraps:
            return FunctoolsWrapsVariable(value)
        elif isinstance(value, re.Pattern):
            return RegexPatternVariable(value)
        elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString):
            return ConstantVariable.create(str(value))
        elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)):
            return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable(
                value
            )
        elif isinstance(value, types.GenericAlias):
            return TypingVariable(value)
        elif is_namedtuple(value):
            field_vts = [
                SourcelessBuilder.create(tx, getattr(value, field_name))
                for field_name in namedtuple_fields(type(value))
            ]
            return NamedTupleVariable(field_vts, tuple_cls=type(value))
        elif (
            isinstance(value, torch.SymInt)
            and value.node.expr in tx.output.bound_symbols
        ):
            proxy = tx.output.bound_symbols[value.node.expr]
            return SymNodeVariable.create(tx, proxy)

        unimplemented_v2(
            gb_type="Unexpected type in sourceless builder",
            context=f"{value_type.__module__}.{value_type.__qualname__}",
            explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}",
            hints=[*graph_break_hints.DYNAMO_BUG],
        )

    @staticmethod
    def wrap_constant_literal(value):
        """Wrap ``value`` as a constant; it must satisfy ``ConstantVariable.is_literal``."""
        assert ConstantVariable.is_literal(value)
        return ConstantVariable.create(value=value)

    @staticmethod
    def make_type_handlers():
        """Build the exact-type dispatch table consulted first by ``create``.

        Returns a dict mapping a type to a ``handler(tx, value)`` callable.
        Entries are inserted in a fixed order so that, should a key repeat, the
        later registration wins.
        """
        create = SourcelessBuilder.create

        def wrap_constant(tx, value):
            return ConstantVariable(value)

        def wrap_set(tx, value):
            return SetVariable(
                [create(tx, item) for item in value],
                mutation_type=ValueMutationNew(),
            )

        def wrap_dict(tx, value):
            return ConstDictVariable(
                {create(tx, key): create(tx, val) for key, val in value.items()},
                type(value),
                mutation_type=ValueMutationNew(),
            )

        def wrap_list(tx, value):
            return ListVariable(
                [create(tx, item) for item in value],
                mutation_type=ValueMutationNew(),
            )

        def wrap_tuple(tx, value):
            return TupleVariable([create(tx, item) for item in value])

        def wrap_size(tx, value):
            return SizeVariable([create(tx, item) for item in value])

        def wrap_random(tx, value):
            return RandomClassVariable()

        def wrap_module(tx, value):
            return PythonModuleVariable(value)

        def wrap_dispatch_key_set(tx, value):
            return DispatchKeySetVariable(value, mutation_type=ValueMutationNew())

        def wrap_functorch_interpreter(tx, value):
            return FuncTorchInterpreterVariable(
                value, mutation_type=ValueMutationNew()
            )

        def wrap_user_object(tx, value):
            return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())

        def passthrough(tx: "InstructionTranslator", value):
            return value

        handlers = dict.fromkeys(common_constant_types, wrap_constant)
        handlers.update(
            {
                set: wrap_set,
                dict: wrap_dict,
                list: wrap_list,
                tuple: wrap_tuple,
                torch.Size: wrap_size,
                # Dict-like and list-like exact types share the dict/list handlers.
                collections.OrderedDict: wrap_dict,
                immutable_dict: wrap_dict,
                immutable_list: wrap_list,
                random.Random: wrap_random,
                types.ModuleType: wrap_module,
                torch.DispatchKeySet: wrap_dispatch_key_set,
                torch._functorch.pyfunctorch.FuncTorchInterpreter: wrap_functorch_interpreter,
            }
        )
        for constraint_cls in (
            torch.distributions.constraints._Real,
            torch.distributions.constraints._Interval,
            torch.distributions.constraints.Constraint,
        ):
            handlers[constraint_cls] = wrap_user_object

        # Every registered VariableTracker subclass is returned as-is.
        for vt_cls in VariableTrackerMeta.all_subclasses:
            handlers[vt_cls] = passthrough
        return handlers


# Built once at import time and attached after the class body, because
# make_type_handlers refers to SourcelessBuilder.create.
SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
|
|
|
|
|
|
|
|
class SourcelessUserDefinedObjectBuilder:
    """
    SourcelessBuilder does not wrap arbitrary objects in a
    UserDefinedObjectVariable, but in some cases returning a user-defined
    object wrapper is acceptable. Use this builder for those cases.

    The class is never instantiated; call
    ``SourcelessUserDefinedObjectBuilder.create``.
    """

    def __init__(self) -> None:
        raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")

    @staticmethod
    def create(tx: "InstructionTranslator", value) -> VariableTracker:
        """Wrap ``value`` as a newly created user-defined object.

        The wrapper class depends on the value: MutableMappingVariable when the
        type of ``value`` subclasses MutableMapping,
        UnspecializedNNModuleVariable for ``torch.nn.Module`` instances, and
        UserDefinedObjectVariable otherwise. ``tx`` is not used.
        """
        if issubclass(type(value), MutableMapping):
            wrapper_cls = MutableMappingVariable
        elif isinstance(value, torch.nn.Module):
            wrapper_cls = UnspecializedNNModuleVariable
        else:
            wrapper_cls = UserDefinedObjectVariable
        return wrapper_cls(value, mutation_type=ValueMutationNew())
|
|
|