|
|
import dataclasses
|
|
|
import importlib
|
|
|
import io
|
|
|
import pickle
|
|
|
from abc import abstractmethod
|
|
|
from typing import Any, Callable, NewType, Optional, TypeVar, Union
|
|
|
from typing_extensions import override, Self
|
|
|
|
|
|
import torch
|
|
|
import torch.utils._pytree as pytree
|
|
|
from torch._guards import TracingContext
|
|
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
|
|
|
from torch._subclasses.meta_utils import (
|
|
|
MetaConverter,
|
|
|
MetaTensorDesc,
|
|
|
MetaTensorDescriber,
|
|
|
)
|
|
|
from torch.fx.experimental.sym_node import SymNode
|
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
from torch.utils._mode_utils import no_dispatch
|
|
|
|
|
|
|
|
|
_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
|
|
|
|
|
|
|
|
|
def _ops_filter_safe(name: str) -> bool:
|
|
|
"""
|
|
|
An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in
|
|
|
ones where it will be possible to unpickle on any machine which has PyTorch.
|
|
|
"""
|
|
|
|
|
|
return name.startswith(
|
|
|
(
|
|
|
"torch.ops.aten",
|
|
|
"torch.ops.fbgemm",
|
|
|
)
|
|
|
)
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
|
class Options:
|
|
|
|
|
|
|
|
|
ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe
|
|
|
|
|
|
|
|
|
class GraphPickler(pickle.Pickler):
|
|
|
"""
|
|
|
GraphPickler is a Pickler which helps pickling fx graph - in particular
|
|
|
GraphModule.
|
|
|
"""
|
|
|
|
|
|
def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None:
|
|
|
super().__init__(file)
|
|
|
self.options = options or Options()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._unpickle_state = _UnpickleStateToken(object())
|
|
|
|
|
|
|
|
|
|
|
|
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
|
|
|
|
|
|
@override
|
|
|
def reducer_override(
|
|
|
self, obj: object
|
|
|
) -> tuple[Callable[..., Any], tuple[Any, ...]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(obj, FakeTensor):
|
|
|
return _TensorPickleData.reduce_helper(self, obj)
|
|
|
elif isinstance(obj, torch.fx.GraphModule):
|
|
|
return _GraphModulePickleData.reduce_helper(self, obj)
|
|
|
elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
|
|
|
return _OpPickleData.reduce_helper(self, obj)
|
|
|
elif isinstance(obj, ShapeEnv):
|
|
|
return _ShapeEnvPickleData.reduce_helper(self, obj)
|
|
|
elif isinstance(obj, torch.SymInt):
|
|
|
return _SymNodePickleData.reduce_helper(self, obj)
|
|
|
elif isinstance(obj, torch._guards.TracingContext):
|
|
|
return _TracingContextPickleData.reduce_helper(self, obj)
|
|
|
else:
|
|
|
|
|
|
assert not isinstance(obj, torch.fx.Node)
|
|
|
if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
|
|
|
return reduce
|
|
|
|
|
|
|
|
|
|
|
|
return NotImplemented
|
|
|
|
|
|
@override
|
|
|
def persistent_id(self, obj: object) -> Optional[str]:
|
|
|
if obj is self._unpickle_state:
|
|
|
return "unpickle_state"
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
@classmethod
|
|
|
def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes:
|
|
|
"""
|
|
|
Pickle an object.
|
|
|
"""
|
|
|
with io.BytesIO() as stream:
|
|
|
pickler = cls(stream, options)
|
|
|
pickler.dump(obj)
|
|
|
return stream.getvalue()
|
|
|
|
|
|
@staticmethod
|
|
|
def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
|
|
|
"""
|
|
|
Unpickle an object.
|
|
|
"""
|
|
|
state = _UnpickleState(fake_mode)
|
|
|
with io.BytesIO(data) as stream:
|
|
|
unpickler = _GraphUnpickler(stream, state)
|
|
|
return unpickler.load()
|
|
|
|
|
|
|
|
|
class _UnpickleState:
|
|
|
def __init__(self, fake_mode: FakeTensorMode) -> None:
|
|
|
self.fake_mode = fake_mode
|
|
|
self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_UnpickleStateToken = NewType("_UnpickleStateToken", object)
|
|
|
|
|
|
|
|
|
class _GraphUnpickler(pickle.Unpickler):
|
|
|
def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
|
|
|
super().__init__(stream)
|
|
|
self._unpickle_state = unpickle_state
|
|
|
|
|
|
@override
|
|
|
def persistent_load(self, pid: object) -> object:
|
|
|
if pid == "unpickle_state":
|
|
|
return self._unpickle_state
|
|
|
else:
|
|
|
raise pickle.UnpicklingError("Invalid persistent ID")
|
|
|
|
|
|
|
|
|
class _ShapeEnvPickleData:
|
|
|
data: dict[str, object]
|
|
|
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, obj: ShapeEnv
|
|
|
) -> tuple[
|
|
|
Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken]
|
|
|
]:
|
|
|
return cls.unpickle, (cls(obj), pickler._unpickle_state)
|
|
|
|
|
|
def __init__(self, env: ShapeEnv) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
assert not env._translation_validation_enabled
|
|
|
self.data = env.__dict__.copy()
|
|
|
del self.data["tracked_fakes"]
|
|
|
del self.data["fake_tensor_cache"]
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
|
|
|
|
|
|
assert unpickle_state.fake_mode
|
|
|
assert unpickle_state.fake_mode.shape_env
|
|
|
|
|
|
for k, v in self.data.items():
|
|
|
setattr(unpickle_state.fake_mode.shape_env, k, v)
|
|
|
|
|
|
return unpickle_state.fake_mode.shape_env
|
|
|
|
|
|
|
|
|
class _SymNodePickleData:
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls,
|
|
|
pickler: GraphPickler,
|
|
|
obj: _SymNodeT,
|
|
|
) -> tuple[
|
|
|
Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken]
|
|
|
]:
|
|
|
args = (cls(obj.node), pickler._unpickle_state)
|
|
|
if isinstance(obj, torch.SymInt):
|
|
|
return _SymNodePickleData.unpickle_sym_int, args
|
|
|
else:
|
|
|
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
|
|
|
|
|
|
def __init__(self, node: SymNode) -> None:
|
|
|
self.expr = node._expr
|
|
|
self.shape_env = node.shape_env
|
|
|
self.pytype = node.pytype
|
|
|
self.hint = node._hint
|
|
|
|
|
|
def _to_sym_node(self) -> SymNode:
|
|
|
from torch.fx.experimental.sym_node import SymNode
|
|
|
|
|
|
assert self.shape_env is not None
|
|
|
return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
|
|
|
|
|
|
def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
|
|
|
return torch.SymInt(self._to_sym_node())
|
|
|
|
|
|
|
|
|
class _TensorPickleData:
|
|
|
metadata: MetaTensorDesc[FakeTensor]
|
|
|
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, obj: FakeTensor
|
|
|
) -> tuple[
|
|
|
Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken]
|
|
|
]:
|
|
|
return cls.unpickle, (
|
|
|
cls(pickler._meta_tensor_describer, obj),
|
|
|
pickler._unpickle_state,
|
|
|
)
|
|
|
|
|
|
def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metadata = describer.describe_tensor(t)
|
|
|
|
|
|
|
|
|
|
|
|
assert not metadata.view_func or isinstance(
|
|
|
metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
|
|
|
)
|
|
|
self.metadata = dataclasses.replace(metadata, fake_mode=None)
|
|
|
|
|
|
|
|
|
for k in MetaTensorDesc._UNSERIALIZABLE:
|
|
|
if k in ("fake_mode", "view_func"):
|
|
|
continue
|
|
|
assert getattr(self.metadata, k) is None, (
|
|
|
f"not None: {k}: {getattr(self.metadata, k)}"
|
|
|
)
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
|
|
|
|
|
|
metadata = dataclasses.replace(
|
|
|
self.metadata,
|
|
|
fake_mode=unpickle_state.fake_mode,
|
|
|
)
|
|
|
|
|
|
def with_fake(
|
|
|
make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
|
|
|
) -> FakeTensor:
|
|
|
with no_dispatch():
|
|
|
return FakeTensor(
|
|
|
unpickle_state.fake_mode,
|
|
|
make_meta_t(),
|
|
|
device,
|
|
|
)
|
|
|
|
|
|
return unpickle_state.meta_converter.meta_tensor(
|
|
|
metadata,
|
|
|
unpickle_state.fake_mode.shape_env,
|
|
|
with_fake,
|
|
|
None,
|
|
|
None,
|
|
|
)
|
|
|
|
|
|
|
|
|
class _TorchNumpyPickleData:
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, obj: object
|
|
|
) -> Optional[
|
|
|
tuple[
|
|
|
Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken]
|
|
|
]
|
|
|
]:
|
|
|
if data := cls.from_object(obj):
|
|
|
return (cls.unpickle, (data, pickler._unpickle_state))
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
def __init__(self, mod: str, name: str) -> None:
|
|
|
self.mod = mod
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
|
|
|
np = getattr(importlib.import_module(self.mod), self.name)
|
|
|
return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
|
|
|
|
|
|
@classmethod
|
|
|
def from_object(cls, tnp: object) -> Optional[Self]:
|
|
|
if not callable(tnp):
|
|
|
return None
|
|
|
|
|
|
tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
|
|
|
try:
|
|
|
if not (np := tnp_to_np.get(tnp)):
|
|
|
return None
|
|
|
except TypeError:
|
|
|
return None
|
|
|
|
|
|
if not (mod := getattr(np, "__module__", None)):
|
|
|
mod = "numpy"
|
|
|
|
|
|
if not (name := getattr(np, "__name__", None)):
|
|
|
return None
|
|
|
|
|
|
assert np == getattr(importlib.import_module(mod), name)
|
|
|
return cls(mod, name)
|
|
|
|
|
|
|
|
|
class _GraphModulePickleData:
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, obj: torch.fx.GraphModule
|
|
|
) -> tuple[
|
|
|
Callable[[Self, _UnpickleState], torch.fx.GraphModule],
|
|
|
tuple[Self, _UnpickleStateToken],
|
|
|
]:
|
|
|
return cls.unpickle, (
|
|
|
cls(obj, pickler.options),
|
|
|
pickler._unpickle_state,
|
|
|
)
|
|
|
|
|
|
def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None:
|
|
|
|
|
|
if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
|
|
|
_python_code = gm._real_recompile()
|
|
|
else:
|
|
|
_python_code = gm.recompile()
|
|
|
self.gm_dict = gm.__dict__.copy()
|
|
|
del self.gm_dict["_graph"]
|
|
|
self.graph = _GraphPickleData(gm._graph, options)
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
|
|
|
gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
|
|
|
gm.__dict__ = self.gm_dict
|
|
|
gm._graph = self.graph.unpickle(gm, unpickle_state)
|
|
|
return gm
|
|
|
|
|
|
|
|
|
class _NodePickleData:
|
|
|
def __init__(
|
|
|
self,
|
|
|
node: torch.fx.Node,
|
|
|
mapping: dict[torch.fx.Node, "_NodePickleData"],
|
|
|
options: Options,
|
|
|
) -> None:
|
|
|
self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
|
|
|
self.kwargs = pytree.tree_map_only(
|
|
|
torch.fx.Node, lambda n: mapping[n], node.kwargs
|
|
|
)
|
|
|
|
|
|
self.name = node.name
|
|
|
self.op = node.op
|
|
|
self.target = _OpPickleData.pickle(node.target, options)
|
|
|
|
|
|
|
|
|
self.type = node.type
|
|
|
|
|
|
|
|
|
|
|
|
self.meta = node.meta
|
|
|
|
|
|
def unpickle(
|
|
|
self,
|
|
|
graph: torch.fx.Graph,
|
|
|
mapping: dict["_NodePickleData", torch.fx.Node],
|
|
|
unpickle_state: _UnpickleState,
|
|
|
) -> torch.fx.Node:
|
|
|
args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
|
|
|
kwargs = pytree.tree_map_only(
|
|
|
_NodePickleData, lambda n: mapping[n], self.kwargs
|
|
|
)
|
|
|
target = self.target.unpickle(unpickle_state)
|
|
|
assert callable(target) or isinstance(target, str)
|
|
|
node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
|
|
|
node.meta = self.meta
|
|
|
return node
|
|
|
|
|
|
|
|
|
class _OpPickleData:
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, op: object
|
|
|
) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]:
|
|
|
result = cls.pickle(op, pickler.options)
|
|
|
return (result.unpickle, (pickler._unpickle_state,))
|
|
|
|
|
|
@classmethod
|
|
|
def pickle(cls, op: object, options: Options) -> "_OpPickleData":
|
|
|
if isinstance(op, str):
|
|
|
return _OpStrPickleData(op)
|
|
|
|
|
|
name = torch.fx.Node._pretty_print_target(op)
|
|
|
if isinstance(op, torch._ops.OpOverload):
|
|
|
return cls._pickle_op(name, _OpOverloadPickleData, options)
|
|
|
elif isinstance(op, torch._ops.OpOverloadPacket):
|
|
|
return cls._pickle_op(name, _OpOverloadPacketPickleData, options)
|
|
|
elif name.startswith(("builtins.", "math.", "torch.")):
|
|
|
root, detail = name.split(".", 1)
|
|
|
return _OpBuiltinPickleData(root, detail)
|
|
|
elif name.startswith("operator."):
|
|
|
_, detail = name.split(".", 1)
|
|
|
return _OpOperatorPickleData(detail)
|
|
|
else:
|
|
|
|
|
|
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
|
|
|
|
|
|
@staticmethod
|
|
|
def _pickle_op(
|
|
|
name: str,
|
|
|
datacls: Union[
|
|
|
type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"]
|
|
|
],
|
|
|
options: Options,
|
|
|
) -> "_OpPickleData":
|
|
|
if (ops_filter := options.ops_filter) and not ops_filter(name):
|
|
|
from torch._inductor.codecache import BypassFxGraphCache
|
|
|
|
|
|
raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
|
|
|
return datacls(name)
|
|
|
|
|
|
@abstractmethod
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
|
|
pass
|
|
|
|
|
|
@classmethod
|
|
|
def _lookup_global_by_name(cls, name: str) -> object:
|
|
|
"""
|
|
|
Like `globals()[name]` but supports dotted names.
|
|
|
"""
|
|
|
if "." in name:
|
|
|
mod, rest = name.split(".", 1)
|
|
|
root = globals()[mod]
|
|
|
return cls._getattr_by_name(root, rest)
|
|
|
else:
|
|
|
return globals()[name]
|
|
|
|
|
|
@staticmethod
|
|
|
def _getattr_by_name(root: object, name: str) -> object:
|
|
|
"""
|
|
|
Like `getattr(root, name)` but supports dotted names.
|
|
|
"""
|
|
|
while "." in name:
|
|
|
mod, name = name.split(".", 1)
|
|
|
root = getattr(root, mod)
|
|
|
return getattr(root, name)
|
|
|
|
|
|
|
|
|
class _OpStrPickleData(_OpPickleData):
|
|
|
def __init__(self, name: str) -> None:
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> str:
|
|
|
return self.name
|
|
|
|
|
|
|
|
|
class _OpOverloadPickleData(_OpPickleData):
|
|
|
def __init__(self, name: str) -> None:
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
|
|
|
obj = self._lookup_global_by_name(self.name)
|
|
|
assert isinstance(obj, torch._ops.OpOverload)
|
|
|
return obj
|
|
|
|
|
|
|
|
|
class _OpOverloadPacketPickleData(_OpPickleData):
|
|
|
def __init__(self, name: str) -> None:
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
|
|
|
obj = self._lookup_global_by_name(self.name)
|
|
|
assert isinstance(obj, torch._ops.OpOverloadPacket)
|
|
|
return obj
|
|
|
|
|
|
|
|
|
class _OpBuiltinPickleData(_OpPickleData):
|
|
|
def __init__(self, root: str, name: str) -> None:
|
|
|
self.root = root
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
|
|
if self.root == "builtins":
|
|
|
return __builtins__.get(self.name)
|
|
|
elif self.root == "math":
|
|
|
import math
|
|
|
|
|
|
return self._getattr_by_name(math, self.name)
|
|
|
elif self.root == "torch":
|
|
|
return self._getattr_by_name(torch, self.name)
|
|
|
else:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
class _OpOperatorPickleData(_OpPickleData):
|
|
|
def __init__(self, name: str) -> None:
|
|
|
self.name = name
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
|
|
import operator
|
|
|
|
|
|
return self._getattr_by_name(operator, self.name)
|
|
|
|
|
|
|
|
|
class _GraphPickleData:
|
|
|
def __init__(self, graph: torch.fx.Graph, options: Options) -> None:
|
|
|
self.tracer_cls = graph._tracer_cls
|
|
|
self.tracer_extras = graph._tracer_extras
|
|
|
|
|
|
nodes: dict[torch.fx.Node, _NodePickleData] = {}
|
|
|
for node in graph.nodes:
|
|
|
nodes[node] = _NodePickleData(node, nodes, options)
|
|
|
self.nodes = tuple(nodes.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unpickle(
|
|
|
self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
|
|
|
) -> torch.fx.Graph:
|
|
|
graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
|
|
|
|
|
|
nodes: dict[_NodePickleData, torch.fx.Node] = {}
|
|
|
for nd in self.nodes:
|
|
|
nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
|
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
class _TracingContextPickleData:
|
|
|
@classmethod
|
|
|
def reduce_helper(
|
|
|
cls, pickler: GraphPickler, obj: torch._guards.TracingContext
|
|
|
) -> tuple[
|
|
|
Callable[[Self, _UnpickleState], torch._guards.TracingContext],
|
|
|
tuple[Self, _UnpickleStateToken],
|
|
|
]:
|
|
|
return (
|
|
|
cls.unpickle,
|
|
|
(
|
|
|
cls(obj),
|
|
|
pickler._unpickle_state,
|
|
|
),
|
|
|
)
|
|
|
|
|
|
def __init__(self, context: TracingContext) -> None:
|
|
|
|
|
|
self.module_context = context.module_context
|
|
|
self.frame_summary_stack = context.frame_summary_stack
|
|
|
self.loc_in_frame = context.loc_in_frame
|
|
|
self.aot_graph_name = context.aot_graph_name
|
|
|
self.params_flat = context.params_flat
|
|
|
self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
|
|
|
self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
|
|
|
self.output_strides = context.output_strides
|
|
|
self.force_unspec_int_unbacked_size_like = (
|
|
|
context.force_unspec_int_unbacked_size_like
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
|
|
|
context = TracingContext(unpickle_state.fake_mode)
|
|
|
context.module_context = self.module_context
|
|
|
context.frame_summary_stack = self.frame_summary_stack
|
|
|
context.loc_in_frame = self.loc_in_frame
|
|
|
context.aot_graph_name = self.aot_graph_name
|
|
|
context.params_flat = self.params_flat
|
|
|
context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
|
|
|
context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
|
|
|
context.output_strides = self.output_strides
|
|
|
context.force_unspec_int_unbacked_size_like = (
|
|
|
self.force_unspec_int_unbacked_size_like
|
|
|
)
|
|
|
return context
|
|
|
|