|
|
|
|
|
|
|
|
""" |
|
|
This module contains miscellaneous variable tracker implementations for various Python types |
|
|
and features used in Dynamo's symbolic execution. These classes help track and propagate |
|
|
information about different kinds of variables during graph capture. |
|
|
|
|
|
Key classes include: |
|
|
- SuperVariable: Handles super() calls and method resolution |
|
|
- ExceptionVariable: Tracks exception objects |
|
|
- RandomVariable: Manages random number generators |
|
|
- GetAttrVariable: Tracks attribute access |
|
|
- MethodWrapperVariable: Handles method wrappers |
|
|
- PythonModuleVariable: Tracks Python modules |
|
|
- NumpyVariable: Handles numpy functions and types |
|
|
- StringFormatVariable: Manages string formatting |
|
|
- DebuggingVariable: Handles print and logging |
|
|
""" |
|
|
|
|
|
import dataclasses |
|
|
import functools |
|
|
import inspect |
|
|
import itertools |
|
|
import random |
|
|
import re |
|
|
import sys |
|
|
import types |
|
|
import warnings |
|
|
from typing import Optional, TYPE_CHECKING |
|
|
|
|
|
import torch._C |
|
|
import torch._numpy as tnp |
|
|
import torch.utils._pytree as pytree |
|
|
|
|
|
from .. import config, graph_break_hints, trace_rules, variables |
|
|
from ..bytecode_transformation import create_call_function, create_instruction |
|
|
from ..create_parameter_op import do_not_convert_to_tracable_parameter |
|
|
from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 |
|
|
from ..guards import GuardBuilder, install_guard |
|
|
from ..mutation_guard import unpatched_nn_module_init |
|
|
from ..source import ( |
|
|
AttrSource, |
|
|
GenericAttrSource, |
|
|
GetItemSource, |
|
|
TypeMROSource, |
|
|
TypeSource, |
|
|
WeakRefCallSource, |
|
|
) |
|
|
from ..utils import ( |
|
|
check_unspec_or_constant_args, |
|
|
cmp_name_to_op_mapping, |
|
|
identity, |
|
|
is_tensor_base_attr_getter, |
|
|
istype, |
|
|
list_methods, |
|
|
proxy_args_kwargs, |
|
|
tuple_methods, |
|
|
) |
|
|
from .base import VariableTracker |
|
|
from .constant import ConstantVariable |
|
|
from .functions import NestedUserFunctionVariable, UserFunctionVariable |
|
|
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from torch._dynamo.codegen import PyCodegen |
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator |
|
|
|
|
|
|
|
|
class NO_SUCH_SUBOBJ:
    """Truthy sentinel class used as the "missing" default when probing a
    class ``__dict__`` during MRO walks (see ``SuperVariable``), so that a
    genuinely absent attribute can be distinguished from a stored value."""
|
|
|
|
|
|
|
|
class SuperVariable(VariableTracker):
    """Tracks the result of a ``super(typevar, objvar)`` call during symbolic
    execution.

    Attribute access and method calls are resolved by walking the MRO of the
    tracked object's type starting just past ``typevar``, mirroring CPython's
    ``super`` semantics, and then dispatched to the appropriate specialized
    VariableTracker.
    """

    _nonvar_fields = {
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, typevar, objvar=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # VariableTracker for the class given as super()'s first argument.
        self.typevar = typevar
        # VariableTracker for the second argument (instance or class); None
        # would correspond to 1-arg super(), which is asserted against below.
        self.objvar = objvar

    def reconstruct(self, codegen: "PyCodegen"):
        # Emit bytecode rebuilding `super(typevar[, objvar])` at runtime.
        codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, False))
        else:
            codegen.extend_output(create_call_function(1, False))

    def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
        """Resolve ``name`` the way ``super().name`` would.

        Returns a ``(resolved_attribute, source_or_None)`` pair; the source
        (when the object has one) points at the MRO entry that supplied the
        attribute, so a guard can be installed on it.  Note: callers in this
        class pass ``self`` in the ``tx`` slot — ``tx`` is never used here.
        """
        assert self.objvar, "1-arg super not implemented"
        search_type = self.typevar.as_python_constant()

        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        # Class-bound super (objvar is itself a class): search that class's
        # MRO rather than type(objvar)'s.
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        search_mro = type_to_use.__mro__

        try:
            # super() begins its lookup one step past the search type.
            start_index = search_mro.index(search_type) + 1
        except ValueError:
            # search_type is not in the MRO — defer to CPython's own super();
            # no source means no guard can be installed on the result.
            return getattr(super(search_type, type_to_use), name), None

        for index in range(start_index, len(search_mro)):
            # Probe each class __dict__ directly so descriptors come back
            # unbound. NOTE(review): the walrus-`if` also rejects *falsy*
            # stored values (None, 0, ""), so such attributes are skipped and
            # the MRO walk continues — confirm this is intentional.
            if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ):
                if resolved_getattr is not NO_SUCH_SUBOBJ:
                    if type_to_use_source:
                        source = AttrSource(
                            GetItemSource(TypeMROSource(type_to_use_source), index),
                            name,
                        )
                    return resolved_getattr, source

        unimplemented_v2(
            gb_type="Unable to resolve super getattr",
            context="",
            explanation=f"Dynamo failed to trace attribute `{name}` accessed "
            f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
            "because the resolved attribute type is not supported.",
            hints=[
                "Ensure the attribute exists in the parent class.",
                "Check the arguments passed to `super()`.",
            ],
        )

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Trace `super().name` attribute access; constant-fold literals
        (with a guard when a source is available), otherwise return a lazy
        GetAttrVariable."""
        # NOTE(review): `self` is passed where `tx` is declared; the helper
        # ignores its tx parameter, so this is harmless but confusing.
        value, source = self._resolved_getattr_and_source(self, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
        return variables.ConstantVariable.create(value, source=source)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace `super().name(*args, **kwargs)` by resolving the attribute
        through the MRO and dispatching on what kind of callable it is.
        Graph-breaks (via unimplemented_v2) on unsupported shapes."""
        inner_fn, source = self._resolved_getattr_and_source(self, name)

        if inner_fn is object.__init__:
            # object.__init__ is a no-op; return an identity lambda tracker.
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            # Only a freshly-constructed (AttributeMutationNew) module with a
            # zero-arg __init__ call can be traced via the unpatched init.
            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutation_type, AttributeMutationNew)
                and not (args or kwargs)
            ):
                with do_not_convert_to_tracable_parameter():
                    return variables.UserFunctionVariable(
                        unpatched_nn_module_init, source=source
                    ).call_function(tx, [self.objvar] + args, kwargs)
            else:
                unimplemented_v2(
                    gb_type="Unsupported super().__init__() call",
                    context=f"call_method {self} {name} {args} {kwargs}",
                    explanation="Dynamo encountered a super().__init__() call "
                    f"on {objvar} that resolved to a `torch.nn.Module.__init__()` "
                    "call that we cannot trace.",
                    hints=[*graph_break_hints.DIFFICULT],
                )
        elif (
            self.objvar.source
            and hasattr(inner_fn, "__name__")
            and inner_fn.__name__ == "__new__"
            and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
        ):
            # super().__new__(...) on a supported class: delegate to the
            # class variable's own __new__ handling.
            user_cls = inner_fn.__self__
            if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
                user_cls_vt = variables.BuiltinVariable(user_cls)
            else:
                user_cls_source = source.member
                user_cls_vt = variables.UserDefinedClassVariable(
                    user_cls, source=user_cls_source
                )
            return user_cls_vt.call_method(tx, "__new__", args, kwargs)
        elif isinstance(inner_fn, staticmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            # Static method: call the wrapped function without binding objvar.
            return variables.UserFunctionVariable(
                inner_fn.__func__, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, classmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            if isinstance(self.objvar, variables.UserDefinedClassVariable):
                # Called on a class: the class itself is the implicit cls arg.
                cls_variable = self.objvar
            else:
                # Called on an instance: bind the instance's type as cls.
                cls_source = None
                if self.objvar.source:
                    cls_source = TypeSource(self.objvar.source)
                cls_variable = VariableTracker.build(
                    tx, self.objvar.value_type, cls_source
                )

            return variables.UserFunctionVariable(
                inner_fn.__func__, source=AttrSource(source, "__func__")
            ).call_function(tx, [cls_variable, *args], kwargs)
        elif isinstance(inner_fn, types.FunctionType):
            # Plain function from a class __dict__: bind objvar as self.
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif is_standard_setattr(inner_fn) and isinstance(
            self.objvar, UserDefinedObjectVariable
        ):
            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
        elif inner_fn is object.__delattr__:
            attr = args[0]
            try:
                attr = attr.as_python_constant()
            except NotImplementedError as exc:
                unimplemented_v2(
                    gb_type="Non-constant attribute given to `super().__delattr__()`",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo requires the attribute name passed to "
                    "`super().__delattr__(...)` to be a constant (string).",
                    hints=[
                        "Ensure the attribute name is a string literal or a constant variable."
                    ],
                    from_exc=exc,
                )
            if not tx.output.side_effects.is_attribute_mutation(self.objvar):
                unimplemented_v2(
                    gb_type="Attempted super().__delattr__() on an object without mutation tracking",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo needs to track mutations on an object "
                    "before `super().__delattr__` can be used on it. But the "
                    f"object ({self.objvar}) doesn't have attribute mutation "
                    "tracking enabled.",
                    hints=[
                        "Ensure the object is tracked by Dynamo's side effect system.",
                        *graph_break_hints.DYNAMO_BUG,
                    ],
                )

            # Model deletion as storing a DeletedVariable in side effects.
            tx.output.side_effects.store_attr(
                self.objvar, attr, variables.DeletedVariable()
            )
            return variables.ConstantVariable(None)
        elif (
            isinstance(self.objvar, variables.UserDefinedDictVariable)
            and inner_fn in self.objvar._dict_methods
        ):
            # dict-subclass calling a plain dict method: forward to the
            # underlying tracked dict.
            return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedSetVariable)
            and inner_fn in self.objvar._set_methods
        ):
            return self.objvar._set_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedTupleVariable)
            and inner_fn in tuple_methods
        ):
            return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedListVariable)
            and inner_fn in list_methods
        ):
            return self.objvar._list_vt.call_method(tx, name, args, kwargs)
        elif inner_fn is object.__getattribute__:
            attr_name = args[0].value
            # Prefer any attribute value mutated during tracing over the
            # real object's current value.
            if tx.output.side_effects.has_pending_mutation_of_attr(
                self.objvar, attr_name
            ):
                result = tx.output.side_effects.load_attr(
                    self.objvar, attr_name, deleted_ok=True
                )
                if isinstance(result, variables.DeletedVariable):
                    raise_observed_exception(AttributeError, tx)
                return result

            try:
                # Bypass any user-defined __getattribute__ override.
                attr_value = object.__getattribute__(self.objvar.value, attr_name)
            except AttributeError:
                raise_observed_exception(AttributeError, tx)

            attr_source = None
            if self.objvar.source is not None:
                attr_source = GenericAttrSource(self.objvar.source, attr_name)
            return VariableTracker.build(tx, attr_value, attr_source)
        elif inner_fn is torch._C._disabled_torch_function_impl:
            # __torch_function__ disabled-impl: args are (func, types, args
            # tuple, kwargs dict). Unpack the traced containers into plain
            # call args/kwargs.
            func = args[0]
            tf_kwargs = {}
            tf_args = args[2].items
            for hash_key_vt, value_vt in args[3].items.items():
                key_str = hash_key_vt.vt.as_python_constant()
                tf_kwargs[key_str] = value_vt

            # Temporarily disable torch_function subclass dispatch while
            # running the wrapped call, restoring the flag afterwards.
            tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
            tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
            try:
                return func.call_function(tx, tf_args, tf_kwargs)
            finally:
                tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
                    tx_old
                )
        elif (
            isinstance(inner_fn, types.MethodDescriptorType)
            and inner_fn in trace_rules.get_tensor_method()
        ):
            # C-implemented tensor method (e.g. via super() on a Tensor
            # subclass): build a tracker for it and bind objvar as self.
            fn_var = VariableTracker.build(tx, inner_fn, source)
            return fn_var.call_function(tx, [self.objvar] + args, kwargs)

        unimplemented_v2(
            gb_type="Attempted to call a super() attribute that is "
            "not a function or method",
            context=f"call_method {self} {name}",
            explanation="Dynamo does not know how to trace the call "
            f"`super().{name}()` because `super().{name}` is not a "
            "function or method attribute.",
            hints=[
                "Ensure the attribute accessed via `super()` is a standard method or function.",
            ],
        )
|
|
|
|
|
|
|
|
class ExceptionVariable(VariableTracker):
    """Tracks an exception object during symbolic execution, including the
    dunder attributes (__context__/__cause__/__suppress_context__/
    __traceback__) that Python's raise machinery manipulates."""

    def __init__(self, exc_type, args, **kwargs) -> None:
        super().__init__(**kwargs)
        # The Python exception class being tracked.
        self.exc_type = exc_type
        # VariableTrackers for the exception's constructor args.
        self.args = args
        # Implicitly chained exception (`raise X` inside an except block).
        self.__context__ = ConstantVariable(None)
        # Explicitly chained exception (`raise X from Y`).
        self.__cause__ = ConstantVariable(None)
        # Set to True when __cause__ is assigned, matching CPython.
        self.__suppress_context__ = ConstantVariable(False)
        # Tracebacks are not modeled; only None is supported.
        self.__traceback__ = ConstantVariable(None)

    def set_context(self, context: "ExceptionVariable"):
        self.__context__ = context

    def reconstruct(self, codegen: "PyCodegen"):
        # Rebuild the exception at runtime: exc_type(*args), then re-apply any
        # non-default dunder attributes onto the new instance.
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
        )
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), False)

        def codegen_attr(name: str) -> None:
            # Only emit a store when the attribute was changed from its
            # default; ConstantVariable values must still be at defaults.
            attr = getattr(self, name)
            if istype(attr, ConstantVariable):
                assert attr.value in (True, False, None), attr
            else:
                codegen.dup_top()
                codegen(attr)
                codegen.extend_output(codegen.rot_n(2))
                codegen.store_attr(name)

        codegen_attr("__context__")
        codegen_attr("__cause__")
        codegen_attr("__suppress_context__")

    def python_type(self):
        return self.exc_type

    def call_setattr(
        self,
        tx: "InstructionTranslator",
        name_var: VariableTracker,
        val: VariableTracker,
    ):
        """Handle `exc.<dunder> = val` for the four supported dunder
        attributes; raises a traced TypeError for invalid values and
        graph-breaks for anything else."""

        def raise_error(msg):
            raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])

        name = name_var.as_python_constant()
        if name == "__context__":
            self.set_context(val)
        elif name == "__cause__":
            if (isinstance(val, ConstantVariable) and val.value is None) or isinstance(
                val,
                (
                    variables.BuiltinVariable,
                    variables.ExceptionVariable,
                    variables.UserDefinedExceptionClassVariable,
                    variables.UserDefinedExceptionObjectVariable,
                ),
            ):
                self.__cause__ = val
                # Mirrors CPython: setting __cause__ suppresses the implicit
                # context in the rendered traceback.
                self.__suppress_context__ = variables.ConstantVariable(True)
            else:
                raise_error("exception cause must be None or derive from BaseException")
        elif name == "__suppress_context__":
            if isinstance(val, ConstantVariable) and val.value in (True, False):
                self.__suppress_context__ = val
            else:
                # NOTE(review): this message mentions "exception cause" —
                # looks copy-pasted from the __cause__ branch; confirm the
                # intended wording for __suppress_context__.
                raise_error("exception cause must be None or derive from BaseException")
        elif name == "__traceback__":
            if isinstance(val, ConstantVariable) and val.value is None:
                self.__traceback__ = val
            else:
                unimplemented_v2(
                    gb_type="Set Exception object `__traceback__` attribute to not-`None`",
                    context=f"call_setattr {self} {name}",
                    explanation="Dynamo does not support setting the attribute "
                    "'__traceback__' on tracked exception objects to anything "
                    "other than None.",
                    hints=[
                        "Avoid setting '__traceback__' on exception objects "
                        "within traced code, or set it to None."
                    ],
                )
        else:
            unimplemented_v2(
                gb_type="Unsupported attribute assignment on Exception object",
                context=f"call_setattr {self} {name}",
                explanation="Dynamo does not support setting the attribute "
                f"'{name}' on tracked exception objects. Only `__context__`, "
                "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        return variables.ConstantVariable(None)

    def call_method(self, tx, name, args, kwargs):
        if name == "__setattr__":
            return self.call_setattr(tx, *args)
        elif name == "with_traceback":
            # exc.with_traceback(tb) sets __traceback__ and returns the
            # exception itself, as in CPython.
            [tb] = args
            self.call_setattr(tx, ConstantVariable("__traceback__"), tb)
            return self
        else:
            return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__context__":
            return self.__context__
        elif name == "__cause__":
            return self.__cause__
        elif name == "__suppress_context__":
            return self.__suppress_context__
        elif name == "__traceback__":
            # NOTE(review): returns a fresh None rather than
            # self.__traceback__ — consistent only because non-None
            # tracebacks are rejected in call_setattr; confirm intended.
            return variables.ConstantVariable(None)
        elif name == "args":
            return variables.ListVariable(self.args, source=self.source)
        return super().var_getattr(tx, name)

    def __str__(self):
        return f"{self.__class__.__name__}({self.exc_type})"

    __repr__ = __str__
|
|
|
|
|
|
|
|
class UnknownVariable(VariableTracker):
    """
    It could be anything!

    Catch-all tracker with no modeled value or behavior of its own;
    subclasses (e.g. DelayGraphBreakVariable) refine it.
    """
|
|
|
|
|
|
|
|
class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
    """

    def __init__(self, msg=None, **kwargs):
        super().__init__(**kwargs)
        # Human-readable reason for the deferred graph break; surfaced in the
        # graph-break explanation when this placeholder is finally called.
        self.msg = msg

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Calling the placeholder always triggers the deferred graph break;
        # this never returns (unimplemented_v2 raises).
        unimplemented_v2(
            gb_type="Unsupported function call (delayed)",
            context=f"source: {self.source}",
            explanation="Dynamo determined that a graph break should occur "
            f"when calling `{self.source.name()}`. Reason: {self.msg}",
            hints=[],
        )
|
|
|
|
|
|
|
|
class ComptimeVariable(VariableTracker):
    """
    This variable is special, it lets you execute arbitrary code at
    Dynamo compile time

    The tracked callable(s) receive a ComptimeContext and are run eagerly
    inside the compiler rather than traced into the graph.
    """

    def reconstruct(self, codegen: "PyCodegen"):
        # comptime has no runtime representation; it must never be
        # reconstructed into output bytecode.
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from ..comptime import comptime

        # Expose comptime's helper functions (e.g. comptime.print_graph) as
        # traceable user functions.
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Execute the given function immediately at compile time with a
        ComptimeContext, returning a traced None."""
        from ..comptime import ComptimeContext

        # comptime(fn) / comptime(fn, extra) — no keyword form supported.
        assert not kwargs

        assert len(args) <= 2
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # Materialize a real function object from the traced nested
            # function so it can be called eagerly. Closures are disallowed
            # because we cannot rebuild their cells here.
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # Empty closure tuple — asserted above that there are no
                # free variables.
                (),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable.create(None)
|
|
|
|
|
|
|
|
class CellVariable(VariableTracker):
    """Tracks a cell-like storage slot (presumably a Python closure cell —
    confirm against callers), recording any contents it held before tracing
    began."""

    # Value the cell held before tracing started; None when the cell was
    # created during tracing, so there is nothing pre-existing to read.
    pre_existing_contents: Optional[VariableTracker]

    # Name of the local variable this cell corresponds to, when known.
    # Not assigned in this class — presumably set by callers; verify.
    local_name: Optional[str] = None

    def __init__(
        self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.pre_existing_contents = pre_existing_contents
|
|
|
|
|
|
|
|
class NewGlobalVariable(VariableTracker):
    """Marker tracker with no state beyond the base class (presumably for a
    global created during tracing — confirm against call sites)."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
def produce_trampoline_autograd_apply(fn_cls):
    """Build a trampoline that forwards to ``fn_cls.apply``.

    The returned function is tagged with ``_origin`` pointing back at this
    factory, so downstream code can recognize trampolines it produced.
    """

    def trampoline_autograd_apply(*args, **kwargs):
        # Pure pass-through: all positional and keyword arguments go
        # straight to fn_cls.apply.
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply
|
|
|
|
|
|
|
|
class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    _nonvar_fields = {
        "fn_cls",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, fn_cls, **kwargs) -> None:
        super().__init__(**kwargs)
        # The torch.autograd.Function subclass itself (not an instance).
        self.fn_cls = fn_cls

    def call_apply(self, tx: "InstructionTranslator", args, kwargs):
        """Trace `fn_cls.apply(*args, **kwargs)`.

        If any input requires grad (and grad mode is on), trace forward and
        backward via AutogradFunctionApplyVariable; otherwise trace only the
        forward with a fresh FunctionCtx.
        """
        requires_grad = False

        def visit(vt):
            # Any grad-requiring tensor, or any NNModule in training mode,
            # forces the full autograd-aware path below.
            nonlocal requires_grad
            if isinstance(vt, variables.TensorVariable):
                if vt.requires_grad is not False:
                    requires_grad = True
            if isinstance(vt, variables.NNModuleVariable):
                if vt.is_training(tx):
                    requires_grad = True

        VariableTracker.visit(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            if config.capture_autograd_function is False:
                warnings.warn(
                    "The config.capture_autograd_function flag is deprecated, it's now always true."
                )

            from torch._functorch.autograd_function import (
                autograd_function_forward_rewritten,
            )
            from torch.autograd.function import _is_setup_context_defined

            forward_fn = self.fn_cls.forward

            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
            if is_setup_ctx_defined:
                # New-style autograd.Function (separate setup_context): fuse
                # forward + setup_context into a single traceable forward.
                forward_fn = autograd_function_forward_rewritten(
                    self.fn_cls.forward, self.fn_cls.setup_context
                )

            # Custom vjp/jvp methods are not traceable — graph-break.
            vjp_fn = self.fn_cls.vjp
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented_v2(
                    gb_type="Unsupported custom vjp",
                    context=f"call_apply {self} {args} {kwargs}",
                    explanation="Dynamo does not support tracing "
                    "`torch.autograd.Function` subclasses that define "
                    "a custom `vjp` method.",
                    hints=[
                        "Remove the custom `vjp` method if possible.",
                        "Use standard `backward` instead if applicable.",
                        *graph_break_hints.SUPPORTABLE,
                    ],
                )

            jvp_fn = self.fn_cls.jvp
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented_v2(
                    gb_type="Unsupported custom jvp",
                    context=f"call_apply {self} {args} {kwargs}",
                    explanation="Dynamo does not support tracing "
                    "`torch.autograd.Function` subclasses that define "
                    "a custom `jvp` method.",
                    hints=[
                        "Remove the custom `jvp` method if possible.",
                        *graph_break_hints.SUPPORTABLE,
                    ],
                )

            from .higher_order_ops import AutogradFunctionApplyVariable

            source = self.source
            if source is None:
                # Synthesize a source via the class's defining module so
                # guards can still be expressed.
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)

            # Guard on the identity of forward (and setup_context when
            # present) so a re-defined function invalidates the graph.
            if self.source:
                fwd_src = AttrSource(self.source, "forward")
                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
                if is_setup_ctx_defined:
                    setup_ctx_src = AttrSource(self.source, "setup_context")
                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))

            return val

        # Inference path: no autograd needed, trace forward directly.
        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            sig = inspect.signature(fn)
            # New-style forward takes no ctx; drop the one we prepended if
            # the parameter count says it isn't expected.
            if len(args) - 1 == len(sig._parameters):
                args = args[1:]
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented_v2(
                gb_type="Non-function or method in subclass of torch.autograd.Function",
                context=f"call_apply {self} {args} {kwargs}",
                explanation="Dynamo requires the `forward` attribute of a "
                "`torch.autograd.Function` subclass to be a standard Python "
                f"function or method. Found type `{type(fn).__name__}` instead.",
                hints=[
                    "Ensure the `forward` method is defined as a regular "
                    "function or instance method."
                ],
            )

    def call_backward(self, tx: "InstructionTranslator", args, kwargs):
        """Trace fn_cls.backward inside compiled autograd; args[0] must be
        the fake ctx supplied by the compiled-autograd machinery."""
        fn = self.fn_cls.backward
        assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
        assert isinstance(fn, types.FunctionType)

        fn_source = AttrSource(self.source, "backward")
        return variables.UserFunctionVariable(fn, source=fn_source).call_function(
            tx, args, kwargs
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        # Calling the class itself just yields another tracker for it
        # (note: constructor args are ignored here).
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ):
        from .builder import wrap_fx_proxy

        if name == "apply":
            if trace_rules.is_callable_allowed(self.fn_cls):
                # Allowed-in-graph function: emit a single call_function node
                # via a trampoline instead of tracing into apply.
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)

        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
            source = AttrSource(self.source, name) if self.source is not None else None
            try:
                obj = inspect.getattr_static(self.fn_cls, name)
            except AttributeError:
                obj = None

            if isinstance(obj, staticmethod):
                func = obj.__get__(self.fn_cls)
                if source is not None:
                    return (
                        trace_rules.lookup(func)
                        .create_with_source(func, source=source)
                        .call_function(tx, args, kwargs)
                    )
                else:
                    return trace_rules.lookup(func)(func).call_function(
                        tx, args, kwargs
                    )
            elif isinstance(obj, classmethod):
                return variables.UserMethodVariable(
                    obj.__func__, self, source=source
                ).call_function(tx, args, kwargs)
            else:
                unimplemented_v2(
                    gb_type="Unsupported autograd.Function method",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo does not support calling the method "
                    f"`{name}` directly on the `torch.autograd.Function` "
                    "instance. Supported methods include `apply`, `backward`, "
                    "static methods, and class methods.",
                    hints=[
                        "Ensure the method is decorated with `@staticmethod` "
                        "or `@classmethod` if it's meant to be called on the class.",
                    ],
                )
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class SavedTensorBox:
    """Mutable holder for the VariableTrackers passed to
    `ctx.save_for_backward` on a traced autograd.Function context."""

    tensors: list[VariableTracker] = dataclasses.field(default_factory=list)
|
|
|
|
|
|
|
|
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        saved_tensors=None,
        needs_input_grad=None,
        non_differentiable=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)
        # True for contexts built on the no-autograd (inference) trace path.
        self.inference = inference
        # SavedTensorBox (or None) backing ctx.save_for_backward/saved_tensors.
        self.saved_tensors = saved_tensors
        # Tuple of bools mirroring ctx.needs_input_grad, when precomputed.
        self.needs_input_grad = needs_input_grad
        # Proxies marked via ctx.mark_non_differentiable, when any.
        self.non_differentiable = non_differentiable

    @staticmethod
    def create(tx: "InstructionTranslator", args=None, kwargs=None):
        """Construct a tracked FunctionCtx registered with side-effect
        tracking; precomputes needs_input_grad from positional args only."""
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented_v2(
                gb_type="proxy not set",
                context=f"as_proxy {self}",
                explanation="Dynamo requires the autograd.Function context "
                "to be initialized with a proxy.",
                hints=[*graph_break_hints.DYNAMO_BUG],
            )
        return self.proxy

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Supports __setattr__, mark_non_differentiable and
        save_for_backward; any other method graph-breaks."""
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        elif name == "mark_non_differentiable":
            assert len(kwargs) == 0
            self.non_differentiable = proxy_args_kwargs(args, {})[0]
            return variables.ConstantVariable.create(None)

        if name != "save_for_backward":
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context method",
                context=f"call_method {self} {name}",
                explanation="Dynamo does not support calling the method "
                f"`{name}` on `autograd.Function` context objects. Supported "
                "methods are `__setattr__`, `save_for_backward` and "
                "`mark_non_differentiable`.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        if self.saved_tensors is None:
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context `save_for_backward`",
                context=f"call_method {self} {name}",
                explanation="Dynamo requires the `saved_tensors` attribute "
                "to be initialized on the `autograd.Function` context object.",
                hints=[
                    "Ensure that the `saved_tensors` attribute is properly "
                    "initialized before calling `save_for_backward`. "
                    "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.",
                ],
            )

        if not self.inference:
            # Full autograd path: record the save so codegen can replay it.
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # Repeated save_for_backward calls replace (not extend) what was
        # saved previously, as in eager.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name in ["save_for_backward", "mark_non_differentiable"]:
            # Bind the method lazily so `m = ctx.save_for_backward; m(...)`
            # also works.
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            if self.source:
                # Fall back to reading it off the real ctx object.
                source = AttrSource(self.source, "needs_input_grad")
                return VariableTracker.build(tx, self.value.needs_input_grad, source)

        return super().var_getattr(tx, name)
|
|
|
|
|
|
|
|
class AutogradEngineVariable(UserDefinedObjectVariable): |
|
|
""" |
|
|
Represents a torch._C._ImperativeEngine instance. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
value, |
|
|
value_type=None, |
|
|
**kwargs, |
|
|
) -> None: |
|
|
super().__init__(value=value, value_type=value_type, **kwargs) |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
if name == "queue_callback": |
|
|
if torch._dynamo.compiled_autograd.in_compiled_autograd_region: |
|
|
assert tx.one_graph or tx.error_on_graph_break, ( |
|
|
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" |
|
|
) |
|
|
return variables.UserFunctionVariable( |
|
|
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, |
|
|
source=self.source, |
|
|
).call_function( |
|
|
tx, |
|
|
(tx.output.side_effects.get_ca_final_callbacks_var(), *args), |
|
|
kwargs, |
|
|
) |
|
|
else: |
|
|
unimplemented_v2( |
|
|
gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()", |
|
|
context=f"call_method {self} {name}", |
|
|
explanation="queue_callback() is only supported when " |
|
|
"Compiled Autograd is enabled with fullgraph=True.", |
|
|
hints=[], |
|
|
) |
|
|
else: |
|
|
unimplemented_v2( |
|
|
gb_type="Unsupported torch._C._ImperativeEngine method", |
|
|
context=f"call_method {self} {name}", |
|
|
explanation="Dynamo only supports the `queue_callback` method " |
|
|
f"on a torch._C._ImperativeEngine instance, but found: `{name}`.", |
|
|
hints=[], |
|
|
) |
|
|
|
|
|
|
|
|
class LambdaVariable(VariableTracker): |
|
|
def __init__(self, fn, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.fn = fn |
|
|
|
|
|
def call_function( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
return self.fn(*args, **kwargs) |
|
|
|
|
|
|
|
|
class GetAttrVariable(VariableTracker): |
|
|
_nonvar_fields = { |
|
|
"name", |
|
|
"py_type", |
|
|
*VariableTracker._nonvar_fields, |
|
|
} |
|
|
|
|
|
def __init__(self, obj, name, py_type=None, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
assert isinstance(obj, VariableTracker) |
|
|
assert isinstance(name, str) |
|
|
self.obj = obj |
|
|
self.name = name |
|
|
self.py_type = py_type |
|
|
|
|
|
def python_type(self): |
|
|
if self.py_type is not None: |
|
|
return self.py_type |
|
|
else: |
|
|
return super().python_type() |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return f"{self.__class__.__name__}({self.obj}, {self.name})" |
|
|
|
|
|
@staticmethod |
|
|
def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): |
|
|
return getattr(base_proxy, attr) |
|
|
|
|
|
def as_proxy(self): |
|
|
return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) |
|
|
|
|
|
def as_python_constant(self): |
|
|
constant = self.obj.as_python_constant() |
|
|
try: |
|
|
return getattr(constant, self.name) |
|
|
except AttributeError: |
|
|
raise NotImplementedError(f"{self} is not a constant") from None |
|
|
|
|
|
def const_getattr(self, tx: "InstructionTranslator", name): |
|
|
if not isinstance(self.obj, variables.NNModuleVariable): |
|
|
raise NotImplementedError |
|
|
step1 = tx.output.get_submodule(self.obj.module_key) |
|
|
if self.name not in step1.__dict__: |
|
|
raise NotImplementedError |
|
|
step2 = inspect.getattr_static(step1, self.name) |
|
|
if name not in step2.__dict__: |
|
|
raise NotImplementedError |
|
|
return inspect.getattr_static(step2, name) |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
codegen(self.obj) |
|
|
codegen.extend_output(codegen.create_load_attrs(self.name)) |
|
|
|
|
|
def call_function( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
return self.obj.call_method(tx, self.name, args, kwargs) |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: list[VariableTracker], |
|
|
kwargs: dict[str, VariableTracker], |
|
|
) -> VariableTracker: |
|
|
if ( |
|
|
name in ("__getitem__", "get") |
|
|
and self.name == "__dict__" |
|
|
and not kwargs |
|
|
and args[0].is_python_constant() |
|
|
and isinstance( |
|
|
self.obj, |
|
|
( |
|
|
variables.UserDefinedObjectVariable, |
|
|
variables.NNModuleVariable, |
|
|
variables.UserDefinedClassVariable, |
|
|
), |
|
|
) |
|
|
): |
|
|
obj = self.obj |
|
|
key = args[0].as_python_constant() |
|
|
if obj.has_key_in_generic_dict(tx, key): |
|
|
|
|
|
return obj.var_getattr(tx, key) |
|
|
|
|
|
|
|
|
if name == "get": |
|
|
if len(args) == 2: |
|
|
return args[1] |
|
|
else: |
|
|
return variables.ConstantVariable(None) |
|
|
|
|
|
elif ( |
|
|
name == "__contains__" |
|
|
and self.name == "__dict__" |
|
|
and len(args) == 1 |
|
|
and args[0].is_python_constant() |
|
|
and not kwargs |
|
|
and isinstance( |
|
|
self.obj, |
|
|
( |
|
|
variables.UserDefinedObjectVariable, |
|
|
variables.NNModuleVariable, |
|
|
variables.UserDefinedClassVariable, |
|
|
), |
|
|
) |
|
|
): |
|
|
obj = self.obj |
|
|
key = args[0].as_python_constant() |
|
|
if obj.has_key_in_generic_dict(tx, key): |
|
|
return variables.ConstantVariable(True) |
|
|
else: |
|
|
return variables.ConstantVariable(False) |
|
|
|
|
|
elif name == "__setitem__" and self.name == "__dict__" and not kwargs: |
|
|
if isinstance(self.obj, variables.UserDefinedObjectVariable): |
|
|
|
|
|
return self.obj.method_setattr_standard( |
|
|
tx, args[0], args[1], directly_update_dict=True |
|
|
) |
|
|
if isinstance(self.obj, variables.NNModuleVariable): |
|
|
|
|
|
self.obj.convert_to_unspecialized(tx) |
|
|
|
|
|
return super().call_method(tx, name, args, kwargs) |
|
|
|
|
|
def get_forwarded_dict(self, tx): |
|
|
assert ( |
|
|
self.name == "__dict__" |
|
|
and isinstance(self.obj, variables.UserDefinedClassVariable) |
|
|
and not tx.output.side_effects.has_pending_mutation(self.obj) |
|
|
) |
|
|
self.obj.ban_mutation = True |
|
|
return VariableTracker.build(tx, self.obj.value.__dict__, self.source) |
|
|
|
|
|
|
|
|
class MethodWrapperVariable(VariableTracker): |
|
|
def __init__(self, method_wrapper, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.method_wrapper = method_wrapper |
|
|
self._builtin_fns = {} |
|
|
|
|
|
def call_function( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( |
|
|
args[0], variables.TensorVariable |
|
|
): |
|
|
assert len(args) == 1 and len(kwargs) == 0 |
|
|
|
|
|
return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self_obj = self.method_wrapper.__self__ |
|
|
wrapper_name = self.method_wrapper.__name__ |
|
|
|
|
|
|
|
|
if wrapper_name == "__init__": |
|
|
fn_obj = type(self_obj).__init__ |
|
|
if fn_obj is object.__init__: |
|
|
return variables.BuiltinVariable(object).call_method( |
|
|
tx, wrapper_name, [self_obj, *args], kwargs |
|
|
) |
|
|
|
|
|
return super().call_function(tx, args, kwargs) |
|
|
|
|
|
def is_python_constant(self): |
|
|
return True |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.method_wrapper |
|
|
|
|
|
|
|
|
class GetSetDescriptorVariable(VariableTracker): |
|
|
def __init__(self, desc, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.desc = desc |
|
|
|
|
|
def var_getattr(self, tx: "InstructionTranslator", name): |
|
|
if name == "__get__" and self.source: |
|
|
source = AttrSource(self.source, "__get__") |
|
|
return VariableTracker.build(tx, self.desc.__get__, source) |
|
|
else: |
|
|
return super().var_getattr(tx, name) |
|
|
|
|
|
def is_python_constant(self): |
|
|
return True |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.desc |
|
|
|
|
|
|
|
|
class PythonModuleVariable(VariableTracker): |
|
|
_nonvar_fields = { |
|
|
"value", |
|
|
"is_torch", |
|
|
*VariableTracker._nonvar_fields, |
|
|
} |
|
|
|
|
|
def __init__(self, value: types.ModuleType, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") |
|
|
|
|
|
def python_type(self): |
|
|
return types.ModuleType |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.value |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return f"PythonModuleVariable({self.value})" |
|
|
|
|
|
def call_obj_hasattr(self, tx: "InstructionTranslator", name): |
|
|
result = hasattr(self.value, name) |
|
|
return variables.ConstantVariable.create(result) |
|
|
|
|
|
def var_getattr(self, tx: "InstructionTranslator", name): |
|
|
if tx.output.side_effects.has_pending_mutation_of_attr(self, name): |
|
|
return tx.output.side_effects.load_attr(self, name) |
|
|
|
|
|
if self.is_torch or name not in self.value.__dict__: |
|
|
try: |
|
|
attr_value = getattr(self.value, name) |
|
|
except AttributeError: |
|
|
raise_observed_exception(AttributeError, tx) |
|
|
else: |
|
|
attr_value = self.value.__dict__[name] |
|
|
|
|
|
source = self.source and AttrSource(self.source, name) |
|
|
return VariableTracker.build(tx, attr_value, source) |
|
|
|
|
|
|
|
|
class TypingVariable(VariableTracker): |
|
|
def __init__(self, value, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
|
|
|
if name == "__getitem__" and len(args) == 1: |
|
|
new_typing = self.value[args[0].as_python_constant()] |
|
|
return TypingVariable(new_typing) |
|
|
unimplemented("unsupported method call on typing variablel") |
|
|
|
|
|
def var_getattr(self, tx: "InstructionTranslator", name: str): |
|
|
from .builder import SourcelessBuilder, VariableBuilder |
|
|
|
|
|
if name in cmp_name_to_op_mapping: |
|
|
return variables.GetAttrVariable(self, name) |
|
|
|
|
|
if tx.output.side_effects.has_pending_mutation_of_attr(self, name): |
|
|
return tx.side_effects.load_attr(self, name) |
|
|
|
|
|
value = getattr(self.value, name) |
|
|
if self.source: |
|
|
attr_source = AttrSource(self.source, name) |
|
|
return VariableBuilder(tx, attr_source)(value) |
|
|
else: |
|
|
return SourcelessBuilder.create(tx, value) |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.value |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen") -> None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
codegen.append_output(codegen.create_load_const(self.value)) |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1) |
|
|
def get_np_to_tnp_map(): |
|
|
""" |
|
|
This generates a mapping from numpy modules to their torch._numpy |
|
|
modules equivalents. |
|
|
""" |
|
|
from ..utils import NP_TO_TNP_MODULE |
|
|
|
|
|
np_fn_to_tnp_fn = {} |
|
|
|
|
|
for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): |
|
|
for fn_name, tnp_fn in tnp_mod.__dict__.items(): |
|
|
if callable(tnp_fn): |
|
|
|
|
|
|
|
|
if np_fn := getattr(np_mod, fn_name, None): |
|
|
np_fn_to_tnp_fn[np_fn] = tnp_fn |
|
|
|
|
|
return np_fn_to_tnp_fn |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1) |
|
|
def get_tnp_to_np_map(): |
|
|
""" |
|
|
This is just the reverse mapping of get_np_to_tnp_map() - mapping from |
|
|
torch._numpy modules to numpy equivalents. |
|
|
""" |
|
|
m = get_np_to_tnp_map() |
|
|
return {v: k for k, v in m.items()} |
|
|
|
|
|
|
|
|
class NumpyVariable(VariableTracker): |
|
|
""" |
|
|
Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. |
|
|
""" |
|
|
|
|
|
constant_fold_functions = (tnp.issubdtype,) |
|
|
|
|
|
def __init__(self, value, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
|
|
|
@classmethod |
|
|
def can_constant_fold_through(cls, fn): |
|
|
mod = fn.__module__.split(".") |
|
|
assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] |
|
|
return fn in cls.constant_fold_functions |
|
|
|
|
|
@classmethod |
|
|
def get_constant_collection_for_func(cls, fn): |
|
|
mod = fn.__module__.split(".") |
|
|
assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] |
|
|
return np_constant_collections_map.get(fn, None) |
|
|
|
|
|
def call_function( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
if not config.trace_numpy: |
|
|
unimplemented(f"numpy.{self.value}()") |
|
|
|
|
|
from ..utils import numpy_to_tensor_wrapper |
|
|
from .tensor import NumpyNdarrayVariable |
|
|
|
|
|
func = get_np_to_tnp_map().get(self.value) |
|
|
if func is None: |
|
|
unimplemented( |
|
|
f"Can't find numpy function {self.value} in torch._numpy. " |
|
|
" Please file an issue to request support for this function." |
|
|
) |
|
|
|
|
|
|
|
|
if ( |
|
|
collection_variable_typ := self.get_constant_collection_for_func(func) |
|
|
) is not None: |
|
|
try: |
|
|
return collection_variable_typ( |
|
|
self.value( |
|
|
*[x.as_python_constant() for x in args], |
|
|
**{k: v.as_python_constant() for k, v in kwargs.items()}, |
|
|
) |
|
|
) |
|
|
except NotImplementedError: |
|
|
unimplemented( |
|
|
f"{self.value.__name__} with non-const args: {args} {kwargs}" |
|
|
) |
|
|
else: |
|
|
if ( |
|
|
func.__module__ == "torch._numpy.random" |
|
|
and config.use_numpy_random_stream |
|
|
): |
|
|
msg = f"delegate '{func.__qualname__}' to NumPy itself via " |
|
|
msg += ( |
|
|
f"config.use_numpy_random_stream={config.use_numpy_random_stream}" |
|
|
) |
|
|
unimplemented(msg) |
|
|
|
|
|
args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) |
|
|
|
|
|
if self.can_constant_fold_through(func) and ( |
|
|
check_unspec_or_constant_args(args, kwargs) |
|
|
): |
|
|
|
|
|
return variables.ConstantVariable.create( |
|
|
self.as_python_constant()( |
|
|
*[x.as_python_constant() for x in args], |
|
|
**{k: v.as_python_constant() for k, v in kwargs.items()}, |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
proxy = tx.output.create_proxy( |
|
|
"call_function", |
|
|
numpy_to_tensor_wrapper(func), |
|
|
*proxy_args_kwargs(args, kwargs), |
|
|
) |
|
|
return NumpyNdarrayVariable.create(tx, proxy) |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
unimplemented("numpy") |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.value |
|
|
|
|
|
def as_proxy(self): |
|
|
if config.trace_numpy and isinstance(self.value, type): |
|
|
|
|
|
|
|
|
|
|
|
return self.value.__name__ |
|
|
|
|
|
return super().as_proxy() |
|
|
|
|
|
|
|
|
|
|
|
class NullVariable(VariableTracker): |
|
|
def __init__(self, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return "NullVariable" |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
if sys.version_info < (3, 11): |
|
|
unimplemented("cannot reconstruct NullVariable in < Python 3.11") |
|
|
codegen.append_output(create_instruction("PUSH_NULL")) |
|
|
|
|
|
|
|
|
class DeletedVariable(VariableTracker): |
|
|
"""Marker used to implement delattr()""" |
|
|
|
|
|
|
|
|
class StringFormatVariable(VariableTracker): |
|
|
""" |
|
|
Represents a call to str.format(), we delay calling format until after the graph. |
|
|
""" |
|
|
|
|
|
_nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} |
|
|
|
|
|
@classmethod |
|
|
def create(cls, format_string, sym_args, sym_kwargs): |
|
|
if all( |
|
|
x.is_python_constant() |
|
|
for x in itertools.chain(sym_args, sym_kwargs.values()) |
|
|
): |
|
|
return variables.ConstantVariable.create( |
|
|
format_string.format( |
|
|
*[v.as_python_constant() for v in sym_args], |
|
|
**{k: v.as_python_constant() for k, v in sym_kwargs.items()}, |
|
|
) |
|
|
) |
|
|
return cls(format_string, list(sym_args), dict(sym_kwargs)) |
|
|
|
|
|
def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
assert isinstance(format_string, str) |
|
|
self.format_string = format_string |
|
|
self.sym_args = sym_args |
|
|
self.sym_kwargs = sym_kwargs |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
codegen.add_push_null( |
|
|
lambda: codegen.extend_output( |
|
|
[ |
|
|
codegen.create_load_const(self.format_string), |
|
|
codegen.create_load_attr("format"), |
|
|
] |
|
|
), |
|
|
call_function_ex=True, |
|
|
) |
|
|
codegen(variables.TupleVariable(self.sym_args)) |
|
|
kwargs = { |
|
|
variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() |
|
|
} |
|
|
codegen(variables.ConstDictVariable(kwargs)) |
|
|
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1)) |
|
|
|
|
|
|
|
|
class DebuggingVariable(VariableTracker): |
|
|
""" |
|
|
Represents a call to a debugging function like print(), or something |
|
|
registered to config.reorderable_logging_functions. |
|
|
""" |
|
|
|
|
|
def __init__(self, value, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
|
|
|
@staticmethod |
|
|
def is_reorderable_logging_function(obj): |
|
|
return ( |
|
|
callable(obj) |
|
|
and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) |
|
|
and obj in torch._dynamo.config.reorderable_logging_functions |
|
|
) |
|
|
|
|
|
def call_function(self, tx: "InstructionTranslator", args, kwargs): |
|
|
if tx.export: |
|
|
|
|
|
return |
|
|
|
|
|
if not self.can_reorder_logs(self.value, args, kwargs): |
|
|
unimplemented( |
|
|
f"Reordering debugging function {self.value} " |
|
|
f"with inputs {args} {kwargs} is not yet implemented." |
|
|
) |
|
|
|
|
|
tx.debug_locals.append((self, list(args))) |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
return self.source.reconstruct(codegen) |
|
|
|
|
|
@staticmethod |
|
|
def can_reorder_logs(fn, args, kwargs) -> True: |
|
|
""" |
|
|
Run some additional checks for what sort of function calls can we |
|
|
actually reorder. |
|
|
""" |
|
|
|
|
|
allowed_input_types = ( |
|
|
variables.TensorVariable, |
|
|
variables.ConstantVariable, |
|
|
StringFormatVariable, |
|
|
) |
|
|
|
|
|
flat_args = pytree.tree_leaves([args, kwargs]) |
|
|
for arg in flat_args: |
|
|
if not isinstance(arg, allowed_input_types): |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
class LoggingLoggerVariable(VariableTracker): |
|
|
""" |
|
|
Represents a call to any of logging.Logger methods |
|
|
""" |
|
|
|
|
|
def __init__(self, value, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
if tx.export: |
|
|
|
|
|
return |
|
|
method = getattr(self.value, name, None) |
|
|
function = getattr(method, "__func__", None) |
|
|
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): |
|
|
return variables.ConstantVariable.create(None) |
|
|
unimplemented( |
|
|
"Logger not supported for non-export cases. " |
|
|
"To avoid graph breaks caused by logger in compile-mode, it is recommended to" |
|
|
" disable logging by adding logging methods to config.ignore_logger_methods" |
|
|
) |
|
|
|
|
|
|
|
|
class ConstantLikeVariable(VariableTracker): |
|
|
"""self.value is a compile-time constant, but not a literal""" |
|
|
|
|
|
_error_prefix = "ConstantLikeVariable" |
|
|
try: |
|
|
from numpy import ( |
|
|
dtype as np_dtype, |
|
|
floating as np_floating, |
|
|
generic as np_generic, |
|
|
) |
|
|
except ImportError: |
|
|
np_floating = type("invalid_type", (), {}) |
|
|
np_dtype = type("invalid_type", (), {}) |
|
|
|
|
|
def __init__(self, value, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
self.value = value |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.value |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: list[VariableTracker], |
|
|
kwargs: dict[str, VariableTracker], |
|
|
) -> VariableTracker: |
|
|
try: |
|
|
|
|
|
cargs = [x.as_python_constant() for x in args] |
|
|
ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} |
|
|
except NotImplementedError: |
|
|
unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})") |
|
|
|
|
|
result = getattr(self.value, name)(*cargs, **ckwargs) |
|
|
|
|
|
if variables.ConstantVariable.is_literal(result): |
|
|
return variables.ConstantVariable.create(result) |
|
|
if isinstance(result, re.Match): |
|
|
return ConstantRegexMatchVariable(result) |
|
|
|
|
|
unimplemented(f"{self._error_prefix}.{name}() -> {result}") |
|
|
|
|
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: |
|
|
result = getattr(self.value, name) |
|
|
if isinstance(result, self.np_floating): |
|
|
result = float(result) |
|
|
if isinstance(result, self.np_dtype): |
|
|
return NumpyDTypeVariable(result) |
|
|
if isinstance(result, type) and issubclass(result, self.np_generic): |
|
|
|
|
|
return NumpyVariable(result) |
|
|
if variables.ConstantVariable.is_literal(result): |
|
|
return variables.ConstantVariable.create(result) |
|
|
return GetAttrVariable(self, name) |
|
|
|
|
|
|
|
|
class RegexPatternVariable(ConstantLikeVariable): |
|
|
_error_prefix = "re.Pattern" |
|
|
|
|
|
|
|
|
class ConstantRegexMatchVariable(ConstantLikeVariable): |
|
|
_error_prefix = "re.Match" |
|
|
|
|
|
|
|
|
class TorchVersionVariable(ConstantLikeVariable): |
|
|
_error_prefix = "torch.__version__" |
|
|
|
|
|
def __init__(self, **kwargs) -> None: |
|
|
kwargs.setdefault("value", torch.__version__) |
|
|
assert kwargs["value"] is torch.__version__ |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
|
|
|
class NumpyTypeInfoVariable(ConstantLikeVariable): |
|
|
_error_prefix = "np.iinfo/np.finfo" |
|
|
|
|
|
|
|
|
class NumpyDTypeVariable(ConstantLikeVariable): |
|
|
_error_prefix = "np.dtype[...]" |
|
|
|
|
|
def as_proxy(self): |
|
|
"""Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: |
|
|
|
|
|
np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. |
|
|
This also handles unsupported things nicely (i.e. structured arrays and object arrays). |
|
|
""" |
|
|
return self.value.type.__name__ |
|
|
|
|
|
|
|
|
np_constant_collections_map = { |
|
|
tnp.finfo: NumpyTypeInfoVariable, |
|
|
tnp.iinfo: NumpyTypeInfoVariable, |
|
|
tnp.dtype: NumpyDTypeVariable, |
|
|
} |
|
|
|
|
|
|
|
|
class RandomClassVariable(VariableTracker): |
|
|
"""random.Random""" |
|
|
|
|
|
def __init__(self, **kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
def call_function(self, tx: "InstructionTranslator", args, kwargs): |
|
|
if len(args) > 1: |
|
|
unimplemented("random.Random() with > 1 arg") |
|
|
elif kwargs: |
|
|
unimplemented("random.Random() with kwargs") |
|
|
seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] |
|
|
return RandomVariable( |
|
|
seed=seed, mutation_type=variables.base.ValueMutationNew() |
|
|
) |
|
|
|
|
|
|
|
|
class RandomVariable(VariableTracker): |
|
|
"""random.Random() |
|
|
|
|
|
Implemented by wrapping a VariableTracker around a random.Random object. |
|
|
The supported methods for the random.Random object cannot be overridden. |
|
|
Assumes that random objects behave the same given a set seed or state. |
|
|
""" |
|
|
|
|
|
_nonvar_fields = { |
|
|
"random", |
|
|
*VariableTracker._nonvar_fields, |
|
|
} |
|
|
|
|
|
_supported_fn_names = { |
|
|
"random", |
|
|
"randint", |
|
|
"randrange", |
|
|
"uniform", |
|
|
} |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
rand: Optional[random.Random] = None, |
|
|
seed: Optional[VariableTracker] = None, |
|
|
**kwargs, |
|
|
) -> None: |
|
|
super().__init__(**kwargs) |
|
|
if rand is not None: |
|
|
assert self.is_supported_random_obj(rand) |
|
|
self.random = random.Random() |
|
|
self.random.setstate(rand.getstate()) |
|
|
else: |
|
|
seed = seed.as_python_constant() if seed is not None else None |
|
|
self.random = random.Random(seed) |
|
|
|
|
|
def python_type(self): |
|
|
return random.Random |
|
|
|
|
|
def as_python_constant(self): |
|
|
return self.random |
|
|
|
|
|
@staticmethod |
|
|
def is_supported_random_obj(val): |
|
|
if type(val) is not random.Random: |
|
|
return False |
|
|
for name in itertools.chain( |
|
|
RandomVariable._supported_fn_names, ("seed", "getstate", "setstate") |
|
|
): |
|
|
if not hasattr(val, name): |
|
|
return False |
|
|
meth = getattr(val, name) |
|
|
if inspect.isbuiltin(meth): |
|
|
|
|
|
if meth != getattr(random.Random, name).__get__(val): |
|
|
return False |
|
|
else: |
|
|
if getattr(meth, "__func__", None) is not getattr(random.Random, name): |
|
|
return False |
|
|
return True |
|
|
|
|
|
@staticmethod |
|
|
def check_state(state): |
|
|
assert type(state) is tuple |
|
|
assert type(state[0]) is int |
|
|
assert type(state[1]) is tuple |
|
|
assert all(type(x) is int for x in state[1]) |
|
|
assert state[2] is None or type(state[2]) is float |
|
|
|
|
|
@staticmethod |
|
|
def wrap_state(state): |
|
|
RandomVariable.check_state(state) |
|
|
return variables.TupleVariable( |
|
|
[ |
|
|
variables.ConstantVariable.create(state[0]), |
|
|
variables.TupleVariable( |
|
|
[variables.ConstantVariable.create(x) for x in state[1]] |
|
|
), |
|
|
variables.ConstantVariable.create(state[2]), |
|
|
] |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def unwrap_state(state): |
|
|
state_obj = state.as_python_constant() |
|
|
RandomVariable.check_state(state_obj) |
|
|
return state_obj |
|
|
|
|
|
def call_method( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
name, |
|
|
args: list[VariableTracker], |
|
|
kwargs: dict[str, VariableTracker], |
|
|
) -> VariableTracker: |
|
|
if name == "seed": |
|
|
tx.output.side_effects.mutation(self) |
|
|
self.random.seed( |
|
|
*[x.as_python_constant() for x in args], |
|
|
**{key: val.as_python_constant() for key, val in kwargs.items()}, |
|
|
) |
|
|
return variables.ConstantVariable.create(None) |
|
|
elif name == "getstate": |
|
|
return self.wrap_state(self.random.getstate()) |
|
|
elif name == "setstate": |
|
|
tx.output.side_effects.mutation(self) |
|
|
self.random.setstate(self.unwrap_state(args[0])) |
|
|
return variables.ConstantVariable.create(None) |
|
|
elif name in self._supported_fn_names: |
|
|
tx.output.side_effects.mutation(self) |
|
|
state = self.random.getstate() |
|
|
|
|
|
def call_random_meth(*args, **kwargs): |
|
|
r = random.Random() |
|
|
r.setstate(state) |
|
|
return getattr(r, name)(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
getattr(self.random, name)( |
|
|
*[x.as_python_constant() for x in args], |
|
|
**{k: v.as_python_constant() for k, v in kwargs.items()}, |
|
|
) |
|
|
|
|
|
return call_random_fn(tx, call_random_meth, args, kwargs) |
|
|
return super().call_method(tx, name, args, kwargs) |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
codegen.add_push_null( |
|
|
lambda: codegen.extend_output( |
|
|
[ |
|
|
codegen.create_load_python_module(random), |
|
|
codegen.create_load_attr("Random"), |
|
|
] |
|
|
) |
|
|
) |
|
|
codegen.call_function(0, False) |
|
|
|
|
|
|
|
|
codegen.dup_top() |
|
|
codegen.load_attr("setstate") |
|
|
codegen(self.wrap_state(self.random.getstate())) |
|
|
codegen.call_function(1, True) |
|
|
codegen.pop_top() |
|
|
|
|
|
|
|
|
class WeakRefVariable(VariableTracker): |
|
|
@staticmethod |
|
|
def build(tx, weakref_value, **options): |
|
|
source = options.get("source", None) |
|
|
callback = weakref_value.__callback__ |
|
|
callback_source = source and AttrSource(source, "__callback__") |
|
|
callback_vt = VariableTracker.build(tx, callback, callback_source) |
|
|
referent = weakref_value() |
|
|
source = source and WeakRefCallSource(source) |
|
|
referent_vt = VariableTracker.build(tx, referent, source) |
|
|
options["source"] = source |
|
|
return WeakRefVariable(referent_vt, callback_vt, **options) |
|
|
|
|
|
def __init__(self, referent_vt, callback_vt, **options): |
|
|
super().__init__(**options) |
|
|
self.referent_vt = referent_vt |
|
|
self.callback_vt = callback_vt |
|
|
|
|
|
def call_function( |
|
|
self, |
|
|
tx: "InstructionTranslator", |
|
|
args: "list[VariableTracker]", |
|
|
kwargs: "dict[str, VariableTracker]", |
|
|
) -> "VariableTracker": |
|
|
return self.referent_vt |
|
|
|
|
|
def reconstruct(self, codegen: "PyCodegen"): |
|
|
codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) |
|
|
codegen(self.referent_vt) |
|
|
codegen(self.callback_vt) |
|
|
codegen.extend_output(create_call_function(2, False)) |
|
|
|