| | """ |
| | This module provides common utilities and base classes for TorchDynamo backends. |
| | |
| | Key components: |
| | - AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends |
| | - Backend utilities for handling: |
| | - Fake tensor conversion |
| | - Device/dtype detection from inputs |
| | - Memory efficient fusion |
| | - Graph flattening |
| | - Common compiler configurations |
| | |
| | The utilities here are used by various backend implementations to handle |
| | common operations and provide consistent behavior across different backends. |
| | AOT autograd functionality is particularly important as it enables ahead-of-time |
| | optimization of both forward and backward passes. |
| | """ |
| |
|
| | import contextlib |
| | import functools |
| | import logging |
| | from collections.abc import Iterable |
| | from typing import Any, Callable |
| | from typing_extensions import ParamSpec, TypeVar |
| | from unittest.mock import patch |
| |
|
| | import torch |
| | from torch._dynamo import disable |
| | from torch._dynamo.exc import TensorifyScalarRestartAnalysis |
| | from torch._dynamo.utils import counters, defake, flatten_graph_inputs |
| | from torch._functorch.aot_autograd import ( |
| | aot_module_simplified, |
| | SerializableAOTDispatchCompiler, |
| | ) |
| | from torch.utils._python_dispatch import _disable_current_modes |
| |
|
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| | P = ParamSpec("P") |
| | R = TypeVar("R") |
| |
|
| |
|


class AotAutograd:
    """Callable backend that lowers a graph through AOT Autograd, forwarding
    the compiler configuration stored in ``self.kwargs`` to
    ``aot_module_simplified``."""

    def __init__(self, **kwargs: Any) -> None:
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
    ) -> Callable[..., Any]:
        if kwargs:
            log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

        # Structured (list/tuple/dict) example inputs can't be handled
        # directly; rewrite the graph to take the flattened leaves as
        # arguments and re-invoke this backend on the flattened graph.
        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # ``decompositions`` may be supplied as a zero-argument factory
        # (e.g. to work around circular imports); materialize it here.
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # Count every invocation; "ok"/"not_ok" are recorded below.
        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
            def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
                # Two layers of disable():
                # - the inner one keeps Dynamo from tracing the backward
                #   compiler function itself;
                # - the outer one keeps Dynamo from tracing the compiled
                #   backward pass it returns.
                return disable(
                    disable(
                        bw_compiler_fn, reason="do not trace backward compiler function"
                    )(*args, **kwargs),
                    reason="do not trace generated backwards pass",
                )

            return _wrapped_bw_compiler

        # Default the backward (and, below, the inference) compiler to the
        # forward compiler when not given explicitly.
        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]

        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
            # Wrap the inner callable in place rather than the object itself,
            # so the compiler remains a SerializableAOTDispatchCompiler.
            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
        else:
            bw_compiler = wrap_bw_compiler(bw_compiler)

        self.kwargs["bw_compiler"] = bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # Debug asserts slow down compile time noticeably, so only turn them
        # on by default when the forward compiler is the no-op compiler
        # (i.e. the aot_eager backend).
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config: contextlib.AbstractContextManager[Any] = patch(
                "functorch.compile.config.debug_assert", True
            )
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: gm and example_inputs are handed to AOT Autograd as-is
            # (not cloned).
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg, reason="do not trace AOT-compiled graph")
        except TensorifyScalarRestartAnalysis:
            # Restart-analysis requests propagate without being counted as
            # compilation failures.
            raise
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise


def aot_autograd(**kwargs: Any) -> AotAutograd:
    return AotAutograd(**kwargs)
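
# Illustrative usage (a sketch, not part of this module): building a minimal
# custom backend with aot_autograd(). `my_fw_compiler` and `some_fn` are
# hypothetical names supplied by the caller.
#
#     def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
#         return gm.forward  # no real compilation; just run the traced graph
#
#     my_backend = aot_autograd(fw_compiler=my_fw_compiler)
#     compiled_fn = torch.compile(some_fn, backend=my_backend)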


def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # These mirror the configuration used by memory_efficient_fusion().
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs
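
# Illustrative usage (a sketch, under the assumption that these kwargs feed
# aot_autograd(), as their keys suggest):
#
#     fusion_backend = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
#     compiled_fn = torch.compile(some_fn, backend=fusion_backend)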


def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
    """
    Decorator for backends that need real inputs: fake example tensors are
    swapped out for real zero tensors before the backend runs.
    """

    @functools.wraps(fn)
    def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
        # Exit fake-tensor (and any other active dispatch) modes so that
        # real tensors can be constructed.
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
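
# Illustrative usage (hypothetical backend, not defined in this module): a
# backend that executes concrete values and so cannot run under fake tensors.
#
#     @fake_tensor_unsupported
#     def my_eager_backend(model, inputs):
#         model(*inputs)  # `inputs` are now real (zero) tensors, safe to run
#         return model.forward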


def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
    """Return the device of the first input that has one; default to CPU."""
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device
    return torch.device("cpu")


def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
    """Return the dtype of the first input that has one; default to float32."""
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
    return torch.float32
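
# Illustrative usage (a sketch; `example_inputs` stands for the example
# inputs a backend receives):
#
#     device = device_from_inputs(example_inputs)  # e.g. torch.device("cuda", 0)
#     dtype = dtype_from_inputs(example_inputs)    # e.g. torch.float16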