|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
import functools |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
import torch |
|
|
import torch.utils._pytree as pytree |
|
|
from torch.fx import Tracer, GraphModule |
|
|
from torch._subclasses.fake_tensor import FakeTensorMode |
|
|
from torch._dispatch.python import enable_python_dispatcher |
|
|
import torch.fx as fx |
|
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata |
|
|
from contextlib import contextmanager, nullcontext |
|
|
import inspect |
|
|
from dataclasses import dataclass |
|
|
import weakref |
|
|
import operator |
|
|
|
|
|
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode |
|
|
from torch._subclasses import FakeTensor |
|
|
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat |
|
|
from torch.fx import Proxy |
|
|
|
|
|
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy"] |
|
|
aten = torch.ops.aten
prim = torch.ops.prim

# Decomposition table consulted by proxy_call before emitting an op into the
# graph; swapped in and out via the decompose() context manager below.
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}

# Tensors with at most this many elements are eligible for constant
# propagation during tracing (see proxy_call).
CONSTANT_NUMEL_LIMIT = 1

# Teach pytree to flatten/unflatten torch.Size like an ordinary sequence.
pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
|
|
|
|
|
def fake_signature(fn, nargs):
    """FX gets confused by varargs, de-confuse it.

    Builds a lambda with exactly *nargs* positional parameters (arg0..argN-1)
    that forwards them to *fn*.
    """
    names = [f"arg{i}" for i in range(nargs)]
    arglist = ",".join(names)
    # eval is safe here: the source is generated entirely from nargs.
    return eval(f"lambda {arglist}: fn({arglist})", {"fn": fn})
|
|
|
|
|
@contextmanager
def decompose(decomposition_table):
    """Temporarily install *decomposition_table* as the active table.

    The previous table is restored on exit, even if an exception propagates.
    """
    global CURRENT_DECOMPOSITION_TABLE
    saved_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = saved_table
|
|
|
|
|
|
|
|
# Attribute key under which a per-object {tracer: proxy} mapping is stored in
# the object's __dict__ (an object() so it can never collide with real attrs).
proxy_slot = object()
# Sentinel distinguishing "no default supplied" from an explicit None default.
no_default = object()
|
|
|
|
|
def set_proxy_slot(obj, tracer, proxy):
    """Register *proxy* as the trace-time stand-in for *obj* under *tracer*."""
    # Keyed weakly by tracer so a dead tracer doesn't pin its proxies alive.
    slots = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())
    assert isinstance(slots, weakref.WeakKeyDictionary)
    slots[tracer] = proxy
|
|
|
|
|
def has_proxy_slot(obj, tracer):
    """Return True iff *obj* has a proxy registered for *tracer*."""
    found = get_proxy_slot(obj, tracer, default=False, transform=lambda _: True)
    return found
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
    """Fetch the proxy registered on *obj* for *tracer*.

    Returns ``transform(proxy)`` on a hit.  On a miss, returns *default* when
    one was supplied, otherwise raises KeyError.
    """
    slots = obj.__dict__.get(proxy_slot)
    # An absent mapping and an empty mapping both count as "not tracked".
    if not slots:
        if default is no_default:
            raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
        return default
    assert isinstance(slots, weakref.WeakKeyDictionary)
    if tracer not in slots:
        if default is no_default:
            raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
        return default
    return transform(slots[tracer])
|
|
|
|
|
|
|
|
def get_proxy_slots(obj):
    """Return the {tracer: proxy} mapping for *obj*, or None if untracked."""
    return obj.__dict__.get(proxy_slot, None)
|
|
|
|
|
|
|
|
|
|
|
def get_proxy(obj):
    """Return the unique proxy tracking *obj*, or None if it is untracked.

    Asserts that at most one tracer is tracking the object.
    """
    slots = get_proxy_slots(obj)
    if slots is None:
        return None
    entries = list(slots.values())
    assert len(entries) == 1
    return entries[0]
|
|
|
|
|
def has_proxy(obj):
    """True when some tracer is currently tracking *obj*."""
    proxy = get_proxy(obj)
    return proxy is not None
|
|
|
|
|
def set_meta(proxy, val):
    """Stamp proxy.node.meta with value/shape metadata derived from *val*.

    Returns *proxy* so calls can be chained.
    """
    meta = proxy.node.meta
    # NB: FakeTensor must be tested before torch.Tensor (it is a subclass).
    if isinstance(val, FakeTensor):
        meta['val'] = val
        meta['tensor_meta'] = _extract_tensor_metadata(val)
    elif isinstance(val, PySymInt):
        meta['val'] = val
    elif isinstance(val, torch.Tensor) and not val.is_sparse:
        # Sparse tensors have no well-defined strides for tensor_meta.
        meta['tensor_meta'] = _extract_tensor_metadata(val)
    return proxy
|
|
|
|
|
def thunkify(f, *args, **kwargs):
    """Return a zero-argument thunk for ``f(*args, **kwargs)``.

    The call is delayed until the thunk is invoked, and the result is cached
    so *f* runs at most once.
    """
    bound = functools.partial(f, *args, **kwargs)
    return functools.lru_cache(1)(bound)
|
|
|
|
|
def track_tensor(tensor, proxy, *, constant, tracer):
    """Register *proxy* as the trace-time representation of *tensor*, and of
    every symbolic int appearing in its shape/strides/numel/storage offset."""
    def try_set_proxy_slot(outer_s, proxy_callable, *args):
        assert callable(proxy_callable)
        # Only symbolic values need a slot; plain ints are left alone.
        if isinstance(outer_s, SymInt):
            inner_s = outer_s.get_pyobj()
            assert isinstance(inner_s, PySymInt)

            # Thunkified: the sym_size/sym_stride node is only materialized in
            # the graph if somebody actually asks for this sym-int's proxy.
            set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args))

    # Track every (possibly symbolic) size and stride so that later sym-int
    # arithmetic can find the proxies it needs.
    for i, s in enumerate(tensor.shape):
        try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)

    for i, s in enumerate(tensor.stride()):
        try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)

    try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
    try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
    # Finally, track the tensor itself (with its constant value, if known).
    set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
|
|
|
|
|
def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
    """Pair each tensor in (possibly nested) *inner_res* with the matching
    entry of *proxy_res*, tracking it under *tracer*.  Returns *inner_res*."""
    def wrap_with_proxy(e, proxy, constant):
        if isinstance(e, torch.Tensor):
            track_tensor(e, proxy, tracer=tracer, constant=constant)
            set_meta(proxy, e)
        elif isinstance(e, list):
            # Nested tensor list in the output: recurse with indexed proxies.
            for idx, ee in enumerate(e):
                wrap_with_proxy(ee, proxy[idx], get_constant(idx))

    def get_constant(idx):
        # NOTE(review): this closes over the OUTER `constant` even when called
        # from inside wrap_with_proxy for a nested list — confirm intended.
        if constant is None:
            return None
        else:
            return constant[idx]

    # Only one level of tuple/list is handled here; deeper nesting is handled
    # by wrap_with_proxy's list branch.
    if isinstance(inner_res, tuple) or isinstance(inner_res, list):
        for idx, e in enumerate(inner_res):
            wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
    elif isinstance(inner_res, torch.Tensor):
        wrap_with_proxy(inner_res, proxy_res, constant)

    return inner_res
|
|
|
|
|
|
|
|
def maybe_disable_fake_tensor_mode():
    """Return a context manager that suspends an active FakeTensorMode.

    If the innermost dispatch mode is a FakeTensorMode, it is popped for the
    duration of the context; otherwise the returned manager is a no-op.
    """
    current_mode = _get_current_dispatch_mode()
    if not isinstance(current_mode, FakeTensorMode):
        return nullcontext()
    return _pop_mode_temporarily()
|
|
|
|
|
|
|
|
@dataclass
class _ProxyTensor:
    """Pairs a traced tensor's fx Proxy with its constant value, if known."""
    # The fx Proxy standing in for the tensor in the traced graph.
    proxy: Proxy
    # Real tensor value when constant propagation applies, else None.
    constant: Optional[torch.Tensor]
|
|
|
|
|
|
|
|
def fetch_sym_proxy(tracer):
    """Return a mapper from SymInt/SymFloat to its proxy (or raw constant)."""
    def inner(e):
        node = e.get_pyobj()
        if node.constant is not None:
            # Constant-backed symbolic values are passed through as plain values.
            return node.constant
        # The slot holds a thunk; force it to materialize the proxy.
        return get_proxy_slot(node, tracer)()
    return inner
|
|
|
|
|
|
|
|
def fetch_tensor_proxy(tracer):
    """Return a mapper from tensor to its _ProxyTensor (or the tensor itself
    when untracked)."""
    def fetch(t):
        return get_proxy_slot(t, tracer, t)
    return fetch
|
|
|
|
|
# Exact tensor types proxy_call handles without needing an explicit proxy slot.
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
|
|
|
|
|
def proxy_call(proxy_mode, func, args, kwargs):
    """Trace a single ATen op: try decompositions first, otherwise emit a
    call_function node, run the real op, track its outputs against the new
    proxy, and constant-propagate tiny results where sound."""
    def can_handle_tensor(x):
        # A tensor is traceable if it is a plain Tensor/Parameter or is
        # already being tracked by this mode's tracer.
        return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)

    # If any tensor argument is a subclass we cannot trace, defer to other
    # handlers.
    if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
        return NotImplemented

    # User-supplied decompositions take priority; re-enter the mode so the
    # decomposed ops themselves get traced.
    if func in CURRENT_DECOMPOSITION_TABLE:
        with proxy_mode:
            r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
            if r is not NotImplemented:
                return r

    # Then try the op's own registered decomposition.
    with proxy_mode:
        r = func.decompose(*args, **kwargs)
        if r is not NotImplemented:
            return r

    tracer = proxy_mode.tracer
    f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))

    # Constant propagation is only sound when every tensor input carries a
    # known constant AND no symbolic ints/floats appear anywhere in the inputs.
    all_constant = (
        pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
        and pytree.tree_all_only((SymInt, SymFloat), lambda _: False, (args, kwargs))
    )

    # Ops whose output depends on tensor *values* (e.g. item()) cannot be
    # represented in the graph; only allowed when inputs are known constants.
    if torch.Tag.data_dependent_output in func.tags:
        if all_constant:
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
            )
            with maybe_disable_fake_tensor_mode():
                return func(*const_args, **const_kwargs)
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar."
        )

    # Swap every _ProxyTensor / symbolic scalar for its fx Proxy.
    proxy_args, proxy_kwargs = pytree.tree_map_only(
        (SymInt, SymFloat),
        fetch_sym_proxy(proxy_mode.tracer),
        pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
    )

    # lift_fresh on a traced tensor must become an explicit copy in the graph,
    # since the traced program may run with different inputs.
    if func is torch.ops.aten.lift_fresh.default:
        func = torch.ops.aten.lift_fresh_copy.default

    proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
                                               name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))

    # In-place ops (trailing "_", not a dunder): repoint the mutated input's
    # .proxy at the new node so later uses read the updated value.
    if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
        # NOTE(review): isinstance against typing.List works (it aliases list)
        # but the idiomatic check is the builtin `list`.
        if isinstance(args[0], List):
            # e.g. foreach-style ops mutating a list of tensors.
            for i, a in enumerate(args[0]):
                a.proxy = proxy_out[0][i]
        else:
            args[0].proxy = proxy_out

    # Actually execute the op (under whatever fake/real mode is active).
    out = func(*args, **kwargs)

    any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))

    constant = None

    # Track a real constant alongside the proxy when the result is tiny; this
    # is what lets a later data-dependent op on it succeed (see above).
    if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT:
        with maybe_disable_fake_tensor_mode():
            constant = args[0].clone()
    elif (
        torch.Tag.nondeterministic_seeded not in func.tags
        and all_constant
        and any_constant
        and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
    ):
        # Recompute the op on the real constant inputs, outside fake mode.
        with maybe_disable_fake_tensor_mode():
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
            )
            constant = func(*const_args, **const_kwargs)
    else:
        constant = None

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    return out
|
|
|
|
|
|
|
|
class PythonKeyTracer(Tracer):
    """fx Tracer variant used by make_fx.

    Differences from the stock Tracer: modules are never treated as leaves
    (their forward is inlined), attribute values pass through unproxied, and
    parameters / constant-backed symbolic scalars are lowered specially in
    create_arg.
    """
    def __init__(self):
        super().__init__()

    # Never treat a module as a leaf: inline its forward into the graph.
    def call_module(
            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        return forward(*args, **kwargs)

    # Module attribute access passes through untouched (no attr proxying).
    def getattr(self, attr, attr_val, parameter_proxy_cache):
        return attr_val

    def create_arg(self, a: Any):
        # Parameters become get_attr nodes; a parameter not registered on the
        # root is installed under a fresh _param_constant{i} name first.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        elif isinstance(a, (SymInt, SymFloat)):
            # A symbolic scalar may only appear here when it is backed by a
            # known constant; embed the constant directly.
            assert a.get_pyobj().constant is not None
            return a.get_pyobj().constant
        return super().create_arg(a)
|
|
|
|
|
|
|
|
def dispatch_trace( |
|
|
root: Union[torch.nn.Module, Callable], |
|
|
tracer: Tracer, |
|
|
concrete_args: Optional[Tuple[Any, ...]] = None, |
|
|
) -> GraphModule: |
|
|
graph = tracer.trace(root, concrete_args) |
|
|
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ |
|
|
return GraphModule(tracer.root, graph, name) |
|
|
|
|
|
|
|
|
def wrap_key(f, tensors, tracer):
    """Wrap *f* so that, when invoked with proxies, the real *tensors* are
    tracked against those proxies before f runs, and traced outputs are
    replaced with their proxies."""
    flat_tensors, tensors_spec = pytree.tree_flatten(tensors)

    @functools.wraps(f)
    def wrapped(*proxies):
        flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
        assert len(flat_proxies) == len(flat_tensors)
        # Bind each real input tensor to its placeholder proxy.
        track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)

        # Run f on the REAL tensors; tracing happens via the dispatch mode.
        out = f(*tensors)
        # Tensor outputs: substitute the recorded proxy (fall back to the
        # tensor itself when untracked, e.g. for constants).
        out = pytree.tree_map_only(
            torch.Tensor,
            lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),
            out
        )
        # Symbolic scalar outputs: force the stored thunk to get a proxy.
        out = pytree.tree_map_only(
            (SymInt, SymFloat),
            lambda t: get_proxy_slot(t.get_pyobj(), tracer)(),
            out
        )
        return out

    return wrapped
|
|
|
|
|
|
|
|
class ProxyTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that records every ATen call into an fx graph via
    proxy_call, pairing itself with a ProxySymDispatchMode for symbolic
    shape arithmetic."""
    def __init__(self, tracer):
        self.tracer = tracer
        # Consulted by inner_torch_dispatch and disable_proxy_modes_tracing().
        self.enable_tracing = True
        # Companion mode that traces arithmetic on symbolic shapes.
        self.sym_mode = ProxySymDispatchMode(tracer)
        self.trace_state = {}
        # Stack of sym_mode.enable(...) managers opened by __enter__.
        self._managers = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Sym tracing must be off while a tensor op is handled, otherwise the
        # sym ops issued during proxy_call would be traced a second time.
        with self.sym_mode.enable(False):
            return self.inner_torch_dispatch(func, types, args, kwargs)

    def __enter__(self):
        # Entering the tensor mode also (re-)enables the symbolic mode.
        m = self.sym_mode.enable(True)
        self._managers.append(m)
        m.__enter__()
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        m = self._managers.pop()
        # Close the base mode first, then the paired sym-mode manager; when
        # the base mode suppressed the exception, exit m cleanly.
        b = super().__exit__(exc_type, exc_value, traceback)
        if not b:
            return m.__exit__(exc_type, exc_value, traceback)
        else:
            return m.__exit__(None, None, None)

    def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
        if not self.enable_tracing:
            return func(*args, **kwargs)

        # Pure metadata queries carry no graph information; run directly.
        if func in [prim.device.default]:
            return func(*args, **kwargs)

        out = proxy_call(self, func, args, kwargs)
        return out
|
|
|
|
|
|
|
|
# Aliases for the symbolic scalar node types exposed by this torch version.
SymInt = torch.SymIntNode
SymFloat = torch.SymFloatNode
|
|
|
|
|
|
|
|
class ProxySymDispatchMode(SymDispatchMode):
    """Traces arithmetic performed on PySymInt/PySymFloat values into the fx
    graph, lazily materializing call_function nodes via thunks."""
    def __init__(self, tracer):
        super().__init__()
        self.tracer = tracer
        # When False, symbolic ops run without being recorded (toggled by
        # ProxyTorchDispatchMode while it handles a tensor op).
        self.enable_tracing = True

    @contextmanager
    def enable(self, b):
        # Temporarily force tracing on/off, restoring the prior state on exit.
        old = self.enable_tracing
        self.enable_tracing = b
        try:
            yield
        finally:
            self.enable_tracing = old

    def _compute_proxy(self, func, args, out):
        # Map each symbolic argument to its graph node (or raw constant);
        # non-symbolic arguments pass through unchanged.
        n_args = tuple(
            get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
            if isinstance(a, (PySymInt, PySymFloat)) else a
            for a in args
        )

        # func has no __torch_function__ that Proxy could interpose on, so
        # create the node manually and wrap it.
        n_out = self.tracer.create_node("call_function", func, n_args, {})
        p_out = fx.Proxy(n_out, self.tracer)
        set_meta(p_out, out)
        return p_out

    def __sym_dispatch__(self, func, types, args, kwargs):
        if not self.enable_tracing:
            return func(*args, **kwargs)

        # Peephole: multiplying by a constant 1 is an identity.
        if func == operator.mul:
            if isinstance(args[1], PySymInt) and args[1].constant == 1:
                return args[0]
            elif isinstance(args[0], PySymInt) and args[0].constant == 1:
                return args[1]

        # Assumes flat positional args only (no nesting, no kwargs).
        assert not kwargs
        out = func(*args, **kwargs)
        assert isinstance(out, (PySymInt, PySymFloat)), f"{func}(*{args}, **{kwargs}) = {out}"

        # Delay emitting the graph node until the proxy is actually needed.
        p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
        set_proxy_slot(out, self.tracer, p_out_thunk)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecompositionInterpreter(torch.fx.Interpreter):
    """Re-runs a GraphModule under ProxyTorchDispatchMode, writing the
    (possibly decomposed) ops into *new_graph*."""
    def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
        super().__init__(module, **kwargs)
        self.new_graph = new_graph
        # Appends the re-traced nodes onto new_graph.
        self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
        self.decomposition_table = decomposition_table
        if self.decomposition_table is None:
            self.decomposition_table = {}
        self.mode = ProxyTorchDispatchMode(self.tracer)

    def placeholder(self, target, args, kwargs):
        out = super().placeholder(target, args, kwargs)
        # Mirror the placeholder into new_graph and track the runtime value.
        proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        return out

    def get_attr(self, target, args, kwargs):
        out = super().get_attr(target, args, kwargs)
        proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        return out

    # call_function / call_method / call_module are traced automatically by
    # the ProxyTorchDispatchMode installed in run().

    def output(self, target, args, kwargs):
        out = super().output(target, args, kwargs)

        # Replace tracked tensors with their nodes; untracked values pass
        # through unchanged.
        def unwrap(e):
            return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node)
        self.new_graph.output(pytree.tree_map(unwrap, out))
        return out

    def run(self, *args, **kwargs):
        # Interpret the original graph while decomposing and re-tracing.
        with decompose(self.decomposition_table), self.mode:
            return super().run(*args, **kwargs)
|
|
|
|
|
|
|
|
def wrapper_and_args_for_make_fx(func, args, kwargs):
    """Flatten (args, kwargs) so *func* can be traced as a function of one
    flat argument list.

    Returns ``(wrapped_fn, flat_args)`` where ``wrapped_fn(flat_args)``
    unflattens and calls ``func(*args, **kwargs)``.
    """
    flat_args, in_spec = pytree.tree_flatten((args, kwargs))

    def wrapped(flat_args):
        call_args, call_kwargs = pytree.tree_unflatten(flat_args, in_spec)
        return func(*call_args, **call_kwargs)

    return wrapped, flat_args
|
|
|
|
|
@contextmanager
def disable_autocast_cache():
    """Temporarily turn off the autocast weight cache, restoring the previous
    setting on exit."""
    prev = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    try:
        yield
    finally:
        torch.set_autocast_cache_enabled(prev)
|
|
|
|
|
|
|
|
def make_fx(f, decomposition_table=None, tracing_mode="real"):
    """Return a wrapper that traces *f* into an fx GraphModule.

    tracing_mode:
      "real"     - trace with the actual input tensors
      "fake"     - trace under FakeTensorMode (no real compute on inputs)
      "symbolic" - fake tensors plus symbolic shapes; the resulting module
                   gets a .shape_env attribute
    """
    assert tracing_mode in ["real", "fake", "symbolic"]

    if decomposition_table is None:
        decomposition_table = {}

    @functools.wraps(f)
    def wrapped(*args):
        # One fx placeholder per top-level argument position.
        phs = pytree.tree_map(lambda _: fx.PH, args)
        fx_tracer = PythonKeyTracer()
        fake_tensor_mode: Any = nullcontext()
        if tracing_mode == "real":
            fake_tensor_mode = nullcontext()
        elif tracing_mode == "fake":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
        elif tracing_mode == "symbolic":
            # No fallback kernels: falling back would bake in concrete shapes.
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        else:
            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

        python_dispatcher_mode: Any = nullcontext()
        if tracing_mode == "symbolic":
            python_dispatcher_mode = enable_python_dispatcher()

        proxy_mode = ProxyTorchDispatchMode(fx_tracer)

        def wrap_fake_concrete(x):
            if isinstance(x, torch.Tensor):
                # Only reached in "fake" mode, where fake_tensor_mode is a
                # real FakeTensorMode (nullcontext has no from_tensor).
                return fake_tensor_mode.from_tensor(x)

            return x

        shape_env = None
        if tracing_mode == "symbolic":
            shape_env = ShapeEnv()
        sym_mode = proxy_mode.sym_mode

        # Same as wrap_fake_concrete but threads the shared ShapeEnv through
        # so all fake tensors agree on their symbolic sizes.
        def wrap_fake_symbolic(x):
            if isinstance(x, torch.Tensor):
                return fake_tensor_mode.from_tensor(x, shape_env=shape_env)
            return x

        wrap_fn_map = {
            "real": lambda x: x,
            "fake": wrap_fake_concrete,
            "symbolic": wrap_fake_symbolic,
        }
        args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

        if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS:
            # FX can't deal with varargs; synthesize a fixed-arity wrapper
            # with one named parameter per placeholder.
            func = fake_signature(f, len(phs))
        else:
            func = f

        # Autocast cache must be off during tracing: the cache clones tensors
        # that here would be proxies, poisoning later non-traced autocast use.
        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
             sym_mode, proxy_mode, disable_autocast_cache():
            t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))

        # Expose the ShapeEnv so callers can inspect/specialize symbolic shapes.
        if tracing_mode == "symbolic":
            t.shape_env = shape_env
        return t

    return wrapped
|
|
|
|
|
|
|
|
def get_torch_dispatch_modes():
    """Return the stack of currently-active torch dispatch modes."""
    stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
    return stack
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def disable_proxy_modes_tracing():
    """Temporarily turn off tracing on every active ProxyTorchDispatchMode,
    restoring each mode's previous flag on exit."""
    proxy_modes = [
        m for m in get_torch_dispatch_modes() if isinstance(m, ProxyTorchDispatchMode)
    ]
    saved_flags = [m.enable_tracing for m in proxy_modes]
    for m in proxy_modes:
        m.enable_tracing = False
    try:
        yield
    finally:
        for m, was_enabled in zip(proxy_modes, saved_flags):
            m.enable_tracing = was_enabled
|
|
|
|
|
|
|
|
def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
    """A helper function used to get the GraphModule for the given func.

    It's expected to be used in the ProxyTensor tracing context.
    It detaches the args and kwargs from the current tracer so that the trace of
    the current graph module can be created without any side-effects.
    """
    flat_fn, flat_args = wrapper_and_args_for_make_fx(func, args, kwargs)

    # Suspend any active proxy modes so this nested trace stays isolated.
    with disable_proxy_modes_tracing():
        gm = make_fx(flat_fn, tracing_mode=tracing_mode)(flat_args)
    return gm
|
|
|