| import abc |
| import builtins |
| import importlib |
| import inspect |
| import logging |
| import pickle |
| import types |
| from dataclasses import dataclass |
| from typing import Any, Callable, Optional |
|
|
| import torch |
| import torch.fx |
| from torch._dynamo.precompile_context import PrecompileContext |
|
|
| from . import convert_frame |
| from .hooks import Hooks |
|
|
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
class SerializableCallable(abc.ABC):
    """
    Interface for compiled callables that can be converted to and from bytes.

    Both hooks are abstract classmethods, so a concrete subclass must
    override them before it can be instantiated.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Return the byte representation of the compiled callable ``fn``."""

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a callable from bytes produced by the serialize hook."""
|
|
|
|
def bind_locals(
    signature: inspect.Signature, /, *args: Any, **kwargs: Any
) -> dict[str, Any]:
    """
    Map call arguments onto the parameter names of ``signature``.

    ``signature`` is positional-only so that a traced function which itself
    has a parameter called ``signature`` can still receive it by keyword;
    otherwise the forwarded ``**kwargs`` would collide with this helper's
    own parameter and raise ``TypeError``.

    Args:
        signature: Signature the arguments are bound against.
        *args: Positional arguments of the call being described.
        **kwargs: Keyword arguments of the call being described.

    Returns:
        The mapping of parameter name to value, with defaults filled in for
        parameters that were not supplied.

    Raises:
        TypeError: If the arguments do not match ``signature``.
    """
    bound_arguments = signature.bind(*args, **kwargs)
    # Fill in unspecified parameters so every name in the signature is present.
    bound_arguments.apply_defaults()
    return bound_arguments.arguments
|
|
|
|
@dataclass
class CompileArtifacts:
    """
    Bundle of everything AOTCompiledFunction needs to rebuild and run a
    compiled function.
    """

    # Signature used to bind call arguments to local names for guard checks.
    signature: inspect.Signature
    # Code object turned into the runnable function via types.FunctionType.
    bytecode: types.CodeType
    # Live guard manager; when None it is rebuilt from ``guards_state``.
    guard_manager: Optional[torch._dynamo.guards.GuardManagerWrapper]
    # Pickled guard state used to reconstruct ``guard_manager``.
    guards_state: bytes
    # Global alias -> module name; each module is imported on construction.
    import_sources: dict[str, str]
    # Global name under which ``compiled_fn`` is exposed to ``bytecode``.
    backend_id: str
    # Backend-produced callable invoked by the generated bytecode.
    compiled_fn: SerializableCallable
    # Code object of the original function, passed to CheckFunctionManager.
    original_code: types.CodeType
    # Closure cells handed to types.FunctionType (None when there are none).
    closure: Optional[tuple[Any, ...]]
|
|
|
|
@dataclass
class AOTCompiledFunction:
    """
    Callable wrapper around a set of CompileArtifacts.

    On construction it materializes the runnable function (``self.fn``) from
    the stored bytecode and, if no guard manager is present, rebuilds one
    from the pickled guard state. Calls are guarded: arguments that fail the
    guard check raise ``RuntimeError`` instead of running the function.
    """

    _artifacts: CompileArtifacts

    def guard_check(self, *args: Any, **kwargs: Any) -> bool:
        """Return whether the guard manager accepts the given call arguments."""
        bound_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
        manager = self._artifacts.guard_manager
        assert manager is not None
        return manager.check(bound_locals)

    def __post_init__(self) -> None:
        """Build ``self.fn`` and, when missing, the guard manager."""
        artifacts = self._artifacts

        # Globals visible to the generated bytecode: the imported modules
        # first, then the compiled callable under its backend id (which wins
        # if the names ever coincide).
        f_globals: dict[str, Any] = {}
        for alias, module_name in artifacts.import_sources.items():
            f_globals[alias] = importlib.import_module(module_name)
        f_globals[artifacts.backend_id] = artifacts.compiled_fn

        self.fn = types.FunctionType(
            artifacts.bytecode, f_globals, closure=artifacts.closure
        )

        if artifacts.guard_manager is not None:
            return

        # No live guard manager (e.g. after deserialize): rebuild it from the
        # pickled guard state.
        restored_state = pickle.loads(artifacts.guards_state)
        check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
            artifacts.original_code,
            restored_state.output_graph,
            shape_code_parts=restored_state.shape_code_parts,
            runtime_global_scope=f_globals,
        )
        artifacts.guard_manager = check_fn_manager.guard_manager

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Run the compiled function if the guards pass.

        Raises:
            RuntimeError: If the guard check rejects the arguments; the
                message carries the verbose guard failure reason.
        """
        assert self._artifacts.guard_manager is not None
        if self.guard_check(*args, **kwargs):
            return self.fn(*args, **kwargs)

        # Guards failed: re-bind the arguments to obtain a verbose reason.
        bound_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
        verbose_result = self._artifacts.guard_manager.check_verbose(bound_locals)
        reason = str(verbose_result)
        raise RuntimeError(f"GuardManager check failed, reason: {reason}")

    def save_compiled_function(self, path: str) -> None:
        """Write the serialized form of this object to ``path``."""
        with open(path, "wb") as out_file:
            payload = type(self).serialize(self)
            out_file.write(payload)

    @classmethod
    def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
        """
        Pickle the artifacts of ``fn``.

        The guard manager is dropped, code objects are converted to
        SerializedCode, and the compiled callable is stored as a
        ``(deserializer, bytes)`` pair.
        """
        from torch._dynamo.package import SerializedCode

        state = dict(vars(fn._artifacts))
        state["guard_manager"] = None
        state["bytecode"] = SerializedCode.from_code_object(state["bytecode"])

        inner_fn = state["compiled_fn"]
        inner_cls = type(inner_fn)
        state["compiled_fn"] = (
            inner_cls.deserialize_compile_artifacts,
            inner_cls.serialize_compile_artifacts(inner_fn),
        )

        state["original_code"] = SerializedCode.from_code_object(
            state["original_code"]
        )
        return pickle.dumps(state)

    @classmethod
    def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
        """
        Rebuild an AOTCompiledFunction from bytes produced by ``serialize``.

        NOTE: this unpickles ``data``; only feed it bytes from a trusted
        source.
        """
        from torch._dynamo.package import SerializedCode

        state = pickle.loads(data)
        state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])

        restore_fn, serialized_fn = state["compiled_fn"]
        state["compiled_fn"] = restore_fn(serialized_fn)

        state["original_code"] = SerializedCode.to_code_object(
            state["original_code"]
        )
        return cls(CompileArtifacts(**state))
|
|
|
|
class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    Attributes not found on the wrapper itself are looked up on the wrapped
    compiled function.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.
        """
        self.compiled_fn = artifact.after_deserialization()
        # Keep the raw content so serialization can return it unchanged.
        self.data = artifact.content

    def __getattr__(self, attr: Any) -> Any:
        """
        Forward attribute lookups that miss on the wrapper to ``compiled_fn``.

        Python only calls ``__getattr__`` after normal lookup has already
        failed, so probing ``self`` again (e.g. with ``hasattr``) would just
        re-enter this method and recurse until ``RecursionError``.

        Raises:
            AttributeError: If ``compiled_fn`` itself is not set yet, or if
                the wrapped function has no such attribute.
        """
        if attr == "compiled_fn":
            # The wrapped function has not been assigned (partially
            # initialized instance); reading self.compiled_fn below would
            # recurse forever.
            raise AttributeError(attr)
        return getattr(self.compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: If PrecompileContext has no artifact for ``backend_id``.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """Return the artifact content captured when ``fn`` was constructed."""
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild the wrapper from bytes returned by serialize_compile_artifacts."""
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The artifact is created with an empty key; only its content is used
        # by __init__.
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped compiled function."""
        return self.compiled_fn(*args, **kwargs)
|
|
|
|
def aot_compile_fullgraph(
    model: Any,
    example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
) -> AOTCompiledFunction:
    """
    Capture ``model`` as a full graph, compile it with ``backend`` and package
    the result as an AOTCompiledFunction.

    Args:
        model: A plain function, or an object exposing ``__self__`` and
            ``__func__`` (e.g. a bound method), in which case ``__self__`` is
            prepended to the positional example arguments.
        example_inputs: ``(args, kwargs)`` used to bind the function's locals
            for capture.
        hooks: Hooks passed to guard building. If ``hooks.guard_filter_fn`` is
            falsy, a default filter is installed on this object in place.
        backend: Callable invoked with the captured graph module and its
            example inputs.

    Returns:
        An AOTCompiledFunction wrapping the collected CompileArtifacts.

    Raises:
        RuntimeError: If ``model`` is of an unsupported kind, or if the
            compiled result does not implement SerializableCallable.
    """
    from torch._dynamo.guards import CheckFunctionManager
    from torch._dynamo.utils import dynamo_timed, get_metrics_context
    from torch._guards import compile_context, CompileContext, TracingContext

    args, kwargs = example_inputs
    if hasattr(model, "__self__"):
        fn = model.__func__
        args = (model.__self__,) + args
    elif inspect.isfunction(model):
        fn = model
    else:
        raise RuntimeError(f"Unsupported model code type {model}")

    signature = inspect.signature(fn)
    f_locals = bind_locals(signature, *args, **kwargs)

    # Closure variables are exposed to the capture as ordinary locals.
    code = fn.__code__
    cells = fn.__closure__
    if code.co_freevars or cells:
        assert len(cells) == len(code.co_freevars)
        freevar_values = {
            name: cell.cell_contents for name, cell in zip(code.co_freevars, cells)
        }
        f_locals.update(freevar_values)

    with (
        compile_context(CompileContext(convert_frame.get_compile_id({}))),
        get_metrics_context(),
        dynamo_timed("fullgraph_capture"),
    ):
        frame_info = convert_frame.FrameInfo(
            code,
            fn.__globals__,
            f_locals,
            builtins.__dict__,
            closure=cells or (),
        )
        capture_output = convert_frame.fullgraph_capture(frame_info)
        dynamo_output = capture_output.dynamo_output

        if not hooks.guard_filter_fn:
            from torch._dynamo.types import GuardFilterEntry

            def new_guard_filter_fn(
                guard_entries: list[GuardFilterEntry],
            ) -> list[bool]:
                # Keep a guard only if it is not global and its type is not in
                # the unsupported-for-serialization set.
                unsupported = CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
                return [
                    (not entry.is_global) and (entry.guard_type not in unsupported)
                    for entry in guard_entries
                ]

            hooks.guard_filter_fn = new_guard_filter_fn

        check_fn = dynamo_output.build_guards(
            code, hooks=hooks, save=True, strict_error=True
        )
        assert check_fn.guards_state is not None

        backend_input = capture_output.backend_input
        backend_id = backend_input.backend_id
        backend_input.graph_module._backend_id = backend_id

        output_graph = dynamo_output.tracer_output.output_graph
        assert output_graph is not None
        import_sources = output_graph.import_sources

        with (
            torch._guards.tracing(TracingContext(backend_input.fake_mode)),
            torch._functorch.config.patch("bundled_autograd_cache", True),
        ):
            compiled_fn = backend(
                backend_input.graph_module, backend_input.example_inputs
            )

            # For the Inductor wrapper, the serializable callable is looked up
            # by backend id rather than taken from the backend's return value.
            if isinstance(backend, torch._TorchCompileInductorWrapper):
                compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
                    backend_id
                )

        if not isinstance(compiled_fn, SerializableCallable):
            # Report the underlying compiler function when the backend wraps one.
            compiler_fn = getattr(backend, "compiler_fn", backend)
            raise RuntimeError(
                f"Compiled function type {type(compiled_fn)} (produced "
                + f"from backend {compiler_fn}) does not implement SerializableCallable."
            )

        artifacts = CompileArtifacts(
            signature=signature,
            bytecode=dynamo_output.bytecode,
            guard_manager=check_fn.guard_manager,
            guards_state=check_fn.guards_state,
            import_sources=import_sources,
            backend_id=backend_id,
            compiled_fn=compiled_fn,
            original_code=code,
            closure=cells,
        )
        aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

    return aot_compiled_fn
|
|