|
|
|
|
|
|
|
|
import itertools
|
|
|
from collections.abc import KeysView, Sequence
|
|
|
from contextlib import contextmanager, nullcontext
|
|
|
from functools import partial, wraps
|
|
|
from typing import Any, Callable, NewType, Optional, Protocol, TypeVar
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
import torch
|
|
|
import torch._dynamo.logging
|
|
|
import torch.nn as nn
|
|
|
import torch.utils._pytree as pytree
|
|
|
import torch.utils.dlpack
|
|
|
from torch import Tensor
|
|
|
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
|
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
|
from torch._dynamo import compiled_autograd
|
|
|
from torch._dynamo.utils import (
|
|
|
CompileEventLogger,
|
|
|
dynamo_timed,
|
|
|
preserve_rng_state,
|
|
|
set_feature_use,
|
|
|
)
|
|
|
from torch._guards import detect_fake_mode
|
|
|
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
|
|
|
from torch._inductor.output_code import OutputCode
|
|
|
from torch._inductor.utils import BoxedBool, InputType
|
|
|
from torch._subclasses import FakeTensor, FakeTensorMode
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
_pytree_subclasses_that_lose_info,
|
|
|
make_fx,
|
|
|
)
|
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
|
|
|
|
|
|
|
|
# Artifact logger used below (see _try_get_metadata_from_dynamo) to record which
# input positions are treated as static for cudagraphs.
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)
|
|
|
from . import config
|
|
|
from ._aot_autograd.autograd_cache import (
|
|
|
AOTAutogradCache,
|
|
|
autograd_cache_key,
|
|
|
should_use_local_autograd_cache,
|
|
|
should_use_remote_autograd_cache,
|
|
|
)
|
|
|
from ._aot_autograd.collect_metadata_analysis import (
|
|
|
run_functionalized_fw_and_collect_metadata,
|
|
|
)
|
|
|
from ._aot_autograd.functional_utils import (
|
|
|
_check_if_mutation_can_be_in_graph,
|
|
|
are_all_mutations_hidden_from_autograd,
|
|
|
are_all_mutations_under_no_grad_or_inference_mode,
|
|
|
assert_functional_graph,
|
|
|
from_fun,
|
|
|
gen_alias_from_base,
|
|
|
has_data_mutation,
|
|
|
has_metadata_mutation,
|
|
|
is_fun,
|
|
|
sync_functional_tensor,
|
|
|
to_fun,
|
|
|
)
|
|
|
from ._aot_autograd.input_output_analysis import (
|
|
|
compute_overlapping_inputs,
|
|
|
create_graph_signature,
|
|
|
create_synthetic_base_metadata,
|
|
|
remove_dupe_metadata,
|
|
|
)
|
|
|
from ._aot_autograd.jit_compile_runtime_wrappers import (
|
|
|
aot_dispatch_autograd,
|
|
|
aot_dispatch_base,
|
|
|
aot_dispatch_export,
|
|
|
)
|
|
|
from ._aot_autograd.logging_utils import (
|
|
|
callback_set,
|
|
|
describe_input,
|
|
|
format_guard_bug_msg,
|
|
|
get_aot_compilation_context,
|
|
|
get_aot_graph_name,
|
|
|
get_graph_being_compiled,
|
|
|
graph_being_compiled,
|
|
|
model_name,
|
|
|
nth_graph,
|
|
|
set_model_name,
|
|
|
setup_stacktrace_preservation_hooks,
|
|
|
track_graph_compiling,
|
|
|
)
|
|
|
from ._aot_autograd.runtime_wrappers import (
|
|
|
AOTDedupeWrapper,
|
|
|
AOTSyntheticBaseWrapper,
|
|
|
)
|
|
|
from ._aot_autograd.schemas import (
|
|
|
AOTConfig,
|
|
|
BackwardSignature,
|
|
|
FQN,
|
|
|
GraphInputName,
|
|
|
GraphOutputName,
|
|
|
GraphSignature,
|
|
|
InputAliasInfo,
|
|
|
MutationType,
|
|
|
OutputAliasInfo,
|
|
|
OutputType,
|
|
|
SubclassCreationMeta,
|
|
|
SubclassMeta,
|
|
|
TensorAlias,
|
|
|
ViewAndMutationMeta,
|
|
|
)
|
|
|
from ._aot_autograd.subclass_utils import (
|
|
|
requires_subclass_dispatch,
|
|
|
unwrap_tensor_subclasses,
|
|
|
unwrap_tensor_subclasses_with_indices_to_original,
|
|
|
wrap_tensor_subclasses,
|
|
|
wrap_tensor_subclasses_maybe_joint,
|
|
|
)
|
|
|
from ._aot_autograd.traced_function_transforms import (
|
|
|
aot_dispatch_subclass,
|
|
|
create_functional_call,
|
|
|
create_functionalized_fn,
|
|
|
create_functionalized_rng_ops_wrapper,
|
|
|
create_joint,
|
|
|
fn_input_mutations_to_outputs,
|
|
|
fn_prepped_for_autograd,
|
|
|
)
|
|
|
from ._aot_autograd.utils import (
|
|
|
_get_autocast_states,
|
|
|
_get_symint_hints,
|
|
|
call_func_at_runtime_with_args,
|
|
|
create_tree_flattened_fn,
|
|
|
KNOWN_TYPES,
|
|
|
make_boxed_compiler,
|
|
|
make_boxed_func,
|
|
|
maybe_to_fresh_input,
|
|
|
normalize_as_list,
|
|
|
partial_flatten_asdict,
|
|
|
root_module_when_exporting_non_strict,
|
|
|
strict_zip,
|
|
|
)
|
|
|
from .partitioners import default_partition
|
|
|
|
|
|
|
|
|
zip = strict_zip
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AOT_COUNTER = itertools.count()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
aot_autograd_decompositions = {}
|
|
|
|
|
|
# Marker type for a flat argument list after process_inputs() has converted
# tensors to FakeTensors (and ints to SymInts where applicable).
FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any])

# The concrete OutputCode subtype produced by a SerializableAOTDispatchCompiler.
TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
|
|
|
|
|
|
|
|
|
class AOTDispatchCompiler(Protocol):
    """
    Represents a fw or bw_compiler passed to AOTAutograd.

    Structural (duck-typed) protocol: any callable accepting an FX GraphModule
    and a sequence of example inputs, returning an arbitrary compiled artifact,
    satisfies it.
    """

    def __call__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
    ) -> Any:
        ...
|
|
|
|
|
|
|
|
|
|
|
|
class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
    """
    An AOTDispatchCompiler whose result is an OutputCode, which is what makes it
    cacheable; a SerializableAOTDispatchCompiler always returns an OutputCode.
    A _CompileFxCallable usually becomes an AOTDispatchCompiler once all of the
    kwargs in _CompileFxKwargs have been bound.
    """

    def __init__(
        self,
        output_code_ty: type[TOutputCode],
        compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
    ):
        # Remember both the expected OutputCode subtype and the underlying
        # compiler callable; __call__ simply delegates to the latter.
        self.output_code_ty = output_code_ty
        self.compiler_fn = compiler_fn

    def __call__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
    ) -> OutputCode:
        # Straight pass-through to the wrapped compiler function.
        compiled = self.compiler_fn(gm, example_inputs)
        return compiled
|
|
|
|
|
|
|
|
|
def process_inputs(
    flat_args: list[Any],
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
    ignore_shape_env: bool = False,
) -> FakifiedFlatArgs:
    """
    Convert each element of ``flat_args`` into its fake/symbolic counterpart
    under ``fake_mode``: plain ints become SymInts (when a shape_env is in use
    and we are not exporting), script objects become fake script objects, and
    real tensors become FakeTensors. Already-fake inputs are passed through
    after asserting they belong to this very ``fake_mode``.
    """
    # All conversions happen with the fake mode active so from_tensor et al.
    # attribute the resulting fakes to this mode.
    with fake_mode:

        def convert(idx, x):
            if shape_env is not None and not ignore_shape_env:
                from torch._dynamo.source import ConstantSource

                if isinstance(x, int):
                    # Export keeps int inputs concrete rather than symbolic.
                    if aot_config.is_export:
                        return x
                    source = ConstantSource(f"sym_{idx}")
                    # Allocate a fresh symbol for this int, keeping the real
                    # value around as its hint.
                    return shape_env.create_symintnode(
                        shape_env.create_symbol(x, source), hint=x, source=source
                    )
            if isinstance(x, torch.ScriptObject):
                return torch._library.fake_class_registry.maybe_to_fake_obj(
                    fake_mode, x
                )
            if not isinstance(x, torch.Tensor):
                # Non-tensor leaves (None, strings, ...) pass through untouched.
                return x
            if isinstance(x, FakeTensor):
                # Inputs that are already fake must belong to our fake mode.
                assert x.fake_mode is fake_mode
                return x
            if is_traceable_wrapper_subclass(x):
                attrs, _ = x.__tensor_flatten__()
                # A subclass whose inner tensors are all already fake is
                # treated like an already-fake input.
                if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
                    assert all(
                        getattr(x, attr).fake_mode is fake_mode for attr in attrs
                    )
                    return x

            symbolic_context = None
            source = None
            trace = True
            # If Dynamo already fakeified this tensor, reuse its symbolic
            # context and source rather than re-deriving them.
            if tracing_context := torch._guards.TracingContext.try_get():
                if x in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[x]
                    source = symbolic_context.tensor_source
                    # Context is already recorded; no need to trace again.
                    trace = False
            if (
                idx < aot_config.num_params_buffers
                and config.static_weight_shapes
                and not symbolic_context
            ):
                # Parameters/buffers (the leading num_params_buffers args) get
                # static shapes when configured and no context overrides that.
                return fake_mode.from_tensor(x, static_shapes=True)

            result = fake_mode.from_tensor(
                x,
                static_shapes=ignore_shape_env,
                symbolic_context=symbolic_context,
                source=source,
                trace=trace,
            )
            return result

        return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])
|
|
|
|
|
|
|
|
|
def construct_fake_mode(
|
|
|
flat_args: list[Any], aot_config: AOTConfig
|
|
|
) -> tuple[FakeTensorMode, Optional[ShapeEnv]]:
|
|
|
fake_mode = detect_fake_mode(flat_args)
|
|
|
if fake_mode is None:
|
|
|
shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
|
|
|
fake_mode = FakeTensorMode(shape_env=shape_env)
|
|
|
else:
|
|
|
shape_env = fake_mode.shape_env
|
|
|
return (fake_mode, shape_env)
|
|
|
|
|
|
|
|
|
def create_aot_dispatcher_function(
    flat_fn,
    fake_flat_args: FakifiedFlatArgs,
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> tuple[Callable, ViewAndMutationMeta]:
    """
    Timed entry point: runs :func:`_create_aot_dispatcher_function` under a
    ``dynamo_timed`` region so the compile time shows up in PT2 compile logs.
    """
    timer = dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True)
    with timer:
        return _create_aot_dispatcher_function(
            flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
        )
|
|
|
|
|
|
|
|
|
def _create_aot_dispatcher_function(
    flat_fn,
    fake_flat_args: FakifiedFlatArgs,
    aot_config: AOTConfig,
    fake_mode: FakeTensorMode,
    shape_env: Optional[ShapeEnv],
) -> tuple[Callable, ViewAndMutationMeta]:
    """
    Traces the forward and backward graphs of the attr:`flat_fn` to generate a
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
    the tracing mechanism to understand the graph capturing details.

    The joint graph is then passed through attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided attr:`fw_compiler` and attr:`bw_compiler`.

    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.

    The calling convention here is that the first aot_config.num_params_buffers
    inputs in flat_args are parameters and buffers, and the rest are inputs.

    We use this to assume that parameters/buffer's shapes don't change.

    Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
    When aot_config.is_export is True, we return an FX graph + metadata
    When aot_config.is_export is False, we return an ordinary runtime function
    """

    # Build the effective decomposition table: explicit config decomps win over
    # the module-level base table; RNG decomps are merged in (lowest priority)
    # only when RNG functionalization is enabled.
    if aot_config.decompositions is None:
        aot_config.decompositions = {}

    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **aot_config.decompositions,
    }

    if config.functionalize_rng_ops:
        aot_config.decompositions = {
            **rng_decompositions,
            **aot_config.decompositions,
        }

    # The python dispatcher is only needed when symbolic shapes are in play.
    python_dispatcher_mode = (
        enable_python_dispatcher() if shape_env is not None else nullcontext()
    )

    # Trace with: autograd multithreading off, RNG state preserved, the fake
    # mode active, the (possibly no-op) python dispatcher, Philox RNG tracking,
    # and saved-tensor hooks disabled.
    with torch.autograd.set_multithreading_enabled(
        False
    ), preserve_rng_state(), (
        fake_mode
    ), (
        python_dispatcher_mode
    ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        from torch._library.fake_class_registry import (
            FakeScriptObject,
            maybe_to_fake_obj,
        )

        def _dup_fake_script_obj(fake_flat_args):
            # Re-fakeify script objects from their real counterparts; other
            # args pass through unchanged.
            return [
                maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
                if isinstance(arg, FakeScriptObject)
                else arg
                for arg in fake_flat_args
            ]

        # Autograd tracing is needed iff any tensor input requires grad.
        needs_autograd = any(
            x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
        )

        with enable_python_dispatcher():
            # Patch out CUDA RNG state setting during metadata collection.
            with patch("torch.cuda.set_rng_state", lambda *args: None):
                mod = root_module_when_exporting_non_strict(flat_fn)
                if mod is not None:
                    # Non-strict export: watch for attribute assignment on the
                    # root module while collecting metadata.
                    ctx = _detect_attribute_assignment(mod)
                else:
                    ctx = nullcontext()

                if torch._functorch.config.fake_tensor_propagate_real_tensors:
                    dynamo_timed_ctx = nullcontext()
                else:
                    dynamo_timed_ctx = dynamo_timed(
                        "aot_collect_metadata", log_pt2_compile_event=True
                    )

                # First metadata pass: functionalize the fn and record
                # view/mutation info about inputs and outputs.
                with dynamo_timed_ctx, ctx:
                    fw_metadata = run_functionalized_fw_and_collect_metadata(
                        flat_fn,
                        static_input_indices=aot_config.static_input_indices,
                        keep_input_mutations=aot_config.keep_inference_input_mutations,
                        is_train=needs_autograd,
                        pre_dispatch=aot_config.pre_dispatch,
                        is_export=aot_config.is_export,
                    )(*_dup_fake_script_obj(fake_flat_args))

                req_subclass_dispatch = requires_subclass_dispatch(
                    fake_flat_args, fw_metadata
                )
                CompileEventLogger.try_add_pt2_compile(
                    "backend_compile", requires_subclass_dispatch=req_subclass_dispatch
                )

                # Safe to skip autograd when no output requires grad (except
                # outputs that merely alias grad-requiring inputs) and no
                # grad-requiring input receives a visible data mutation.
                output_and_mutation_safe = not any(
                    x.requires_grad
                    and not (
                        x.output_type
                        in (OutputType.alias_of_input, OutputType.is_input)
                        and fw_metadata.input_info[x.base_idx].requires_grad
                    )
                    for x in fw_metadata.output_info
                ) and not any(
                    x.requires_grad
                    and x.mutates_data
                    and not x.mutations_under_no_grad_or_inference_mode
                    and not x.mutations_hidden_from_autograd
                    for x in fw_metadata.input_info
                )

                if needs_autograd and output_and_mutation_safe:
                    # Downgrade to inference-only compilation, and refresh the
                    # metadata so it reflects is_train=False.
                    needs_autograd = False
                    if req_subclass_dispatch:
                        # Subclass dispatch: re-run the full metadata pass.
                        fw_metadata = run_functionalized_fw_and_collect_metadata(
                            flat_fn,
                            keep_input_mutations=aot_config.keep_inference_input_mutations,
                            is_train=False,
                            pre_dispatch=aot_config.pre_dispatch,
                            static_input_indices=aot_config.static_input_indices,
                        )(*fake_flat_args)
                    else:
                        # Cheap path: rebuild the metadata with is_train=False
                        # without re-tracing.
                        fw_metadata = ViewAndMutationMeta(
                            input_info=fw_metadata.input_info,
                            output_info=fw_metadata.output_info,
                            num_intermediate_bases=fw_metadata.num_intermediate_bases,
                            keep_input_mutations=aot_config.keep_inference_input_mutations,
                            traced_tangents=fw_metadata.traced_tangents,
                            subclass_inp_meta=fw_metadata.subclass_inp_meta,
                            subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
                            subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                            is_train=False,
                            tokens=fw_metadata.tokens,
                            static_input_indices=fw_metadata.static_input_indices,
                        )

        # Aliased graph outputs are unsupported together with subclass inputs.
        if fw_metadata.num_intermediate_bases > 0:
            assert not req_subclass_dispatch, f"""\
torch.compile is currently being used with tensor subclass inputs:
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
please file a github issue"""

        # Export-only restrictions: metadata mutations, grad+data mutations,
        # subclass dispatch, and functionalized RNG are all banned.
        if aot_config.is_export:
            if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
                raise RuntimeError(
                    f"""\
Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}"""
                )
            if (
                len(
                    [
                        x
                        for x in fw_metadata.input_info
                        if x.requires_grad and x.mutates_data
                    ]
                )
                != 0
            ):
                raise RuntimeError(
                    f"""\
Found a graph input that requires gradients, and received a mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}"""
                )
            if req_subclass_dispatch:
                raise RuntimeError(
                    """\
aot_export is not currently supported with traceable tensor subclass.
If you need this feature, please comment on <CREATE_ISSUE_LINK>"""
                )

            if config.functionalize_rng_ops:
                raise RuntimeError(
                    """\
Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
                )

        def choose_dispatcher(needs_autograd, aot_config):
            """
            Pick a dispatcher based on the config rules.
            """
            if aot_config.is_export:
                # Export path returns a graph + signature.
                CompileEventLogger.try_add_pt2_compile(
                    "backend_compile", dispatch_mode="export"
                )
                return partial(aot_dispatch_export, needs_autograd=needs_autograd)
            elif needs_autograd and not aot_config.pre_dispatch:
                CompileEventLogger.try_add_pt2_compile(
                    "backend_compile", dispatch_mode="autograd"
                )
                return aot_dispatch_autograd
            else:
                CompileEventLogger.try_add_pt2_compile(
                    "backend_compile", dispatch_mode="inference"
                )
                return aot_dispatch_base

        compiler_fn = choose_dispatcher(needs_autograd, aot_config)

        compiled_fn, fw_metadata = compiler_fn(
            flat_fn,
            _dup_fake_script_obj(fake_flat_args),
            aot_config,
            fw_metadata=fw_metadata,
        )
        return compiled_fn, fw_metadata
|
|
|
|
|
|
|
|
|
def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[dict] = None,
    num_params_buffers: int = 0,
    keep_inference_input_mutations: bool = False,
    inference_compiler: Optional[Callable] = None,
    *,
    dynamic: bool = False,
    enable_log: bool = True,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph. :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One can
    set `decompositions` dictionary to decompose the operators into a sequence
    of core or simpler operators supported by the backend compilers.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one ore more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Dict): A dictionary to define the decomposition of
            larger Aten ops into simpler or core Aten ops.
        inference_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. inference_compiler is invoked
            if no autograd is needed. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compile` and :attr:`bw_compile`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

        >>> fn = lambda x : x.sin().cos()
        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> aot_fn = aot_function(fn, print_compile_fn)
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)
    """
    # Missing compilers default to the forward compiler.
    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic,
        aot_autograd_arg_pos_to_source=None,
        is_export=False,
        no_tangents=False,
        enable_log=enable_log,
    )
    # Holds (compiled_fn, out_spec) after the first call; compilation happens
    # once, on the first invocation, and is reused thereafter.
    cached_res = None

    @wraps(fn)
    def returned_function(*args, **kwargs):
        nonlocal cached_res

        # Flatten the pytree of inputs into a flat arg list.
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        # Compile the function lazily, on the first invocation only.
        if cached_res is None:
            flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
            (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config)
            fake_flat_args: FakifiedFlatArgs = process_inputs(
                flat_args, aot_config, fake_mode, shape_env
            )
            compiled_fn, _ = create_aot_dispatcher_function(
                flat_fn,
                fake_flat_args,
                aot_config,
                fake_mode,
                shape_env,
            )
            cached_res = (compiled_fn, out_spec)

        cached_fn, out_spec = cached_res
        out = cached_fn(flat_args)
        # Restore the original output pytree structure.
        return out_spec.unflatten(out)

    return returned_function
|
|
|
|
|
|
|
|
|
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is wrapper function, that underneath uses
    :func:`aot_function` to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.

    """
    # Fake params/buffers cannot be traced through this path.
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

    # Stateless call: params/buffers are passed explicitly so they become
    # graph inputs rather than captured module state.
    def functional_call(named_params, named_buffers, *args, **kwargs):
        params_and_buffers = {**named_params, **named_buffers}
        return torch.func.functional_call(mod, params_and_buffers, args, kwargs)

    named_params = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
    num_params_buffers = len(named_params) + len(named_buffers)
    compiled_f = aot_function(
        functional_call, *args, num_params_buffers=num_params_buffers, **kwargs
    )

    class AOTModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Keep a reference to the original module so its state stays alive.
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            # Prepend the captured params/buffers per functional_call's
            # calling convention.
            return compiled_f(
                named_params,
                named_buffers,
                *args,
                **kwargs,
            )

    return AOTModule()
|
|
|
|
|
|
|
|
|
def _try_get_metadata_from_dynamo(
    mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int
) -> tuple[Optional[list[torch._guards.Source]], list[int]]:
    """
    Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule.
    We first verify that `mod` does come from Dynamo, then we handle cases where
    metadata might be missing.

    Returns:
        aot_autograd_arg_pos_to_source: used to dedup params and their guards
        static_input_indices: used to identify static inputs for cudagraphs
    """
    # Only Dynamo-produced GraphModules carry this meta key; anything else
    # gets no metadata.
    if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta):
        return None, []

    # Dynamo module without the source map — treat as metadata missing.
    if not hasattr(mod, "_param_name_to_source"):
        return None, []

    param_name_to_source = mod._param_name_to_source
    # Used to assert each source appears at most once (dedup invariant).
    seen_sources = set()

    aot_autograd_arg_pos_to_source = []
    static_input_indices = []

    # Params/buffers come first in the flattened arg list; all of them are
    # static inputs.
    for i, name in enumerate(param_keys):
        assert name in param_name_to_source, f"{name} not found."
        source = param_name_to_source[name]
        assert source not in seen_sources, source
        seen_sources.add(source)
        aot_autograd_arg_pos_to_source.append(source)
        static_input_indices.append(i)

    # Remaining args correspond to graph placeholders, in order; each carries
    # the source Dynamo attached to it (may be None).
    for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
        assert hasattr(node, "_dynamo_source")
        source = node._dynamo_source
        assert source is None or source not in seen_sources, source
        seen_sources.add(source)
        aot_autograd_arg_pos_to_source.append(source)
        source_name = source.name() if source else str(source)

        # Placeholder positions are offset by the number of params/buffers.
        actual_pos = pos + len(param_keys)

        # Dynamo marks static tensor inputs via tensor_dict metadata.
        if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
            "_dynamo_static_input_type", None
        ):
            static_inputs_log.debug(
                "Adding static input pos %s for source %s", actual_pos, source_name
            )
            static_input_indices.append(actual_pos)
        else:
            static_inputs_log.debug(
                "Non-static input pos %s for source %s", actual_pos, source_name
            )

    # Every flattened arg must have exactly one (possibly None) source.
    assert full_args_num == len(aot_autograd_arg_pos_to_source)
    return aot_autograd_arg_pos_to_source, static_input_indices
|
|
|
|
|
|
|
|
|
def aot_module_simplified(
    mod: nn.Module,
    args,
    fw_compiler: AOTDispatchCompiler,
    bw_compiler: Optional[AOTDispatchCompiler] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[dict] = None,
    keep_inference_input_mutations: bool = False,
    inference_compiler: Optional[AOTDispatchCompiler] = None,
    cudagraphs: Optional[BoxedBool] = None,
    boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
    ignore_shape_env: bool = False,
) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) Reading of params/buffers in every forward call

    :func:`aot_module_simplified` removes these overheads.
    """
    # Flatten params and buffers once, up front; they become the leading
    # entries of the flat argument list.
    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    if cudagraphs is None:
        cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)

    # Missing compilers default to the forward compiler.
    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler

    full_args = []
    # Calling convention: params/buffers first, then user args.
    full_args.extend(params_flat)

    # Share the flat params (and their subclass-unwrapped form) with the
    # active tracing context, if any.
    if tracing_context := torch._guards.TracingContext.try_get():
        tracing_context.params_flat = params_flat
        (
            tracing_context.params_flat_unwrap_subclasses,
            tracing_context.params_unwrapped_to_flat_index,
        ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat)

    full_args.extend(args)

    # Pull guard sources and static-input positions from Dynamo metadata on
    # `mod`, when available.
    (
        aot_autograd_arg_pos_to_source,
        static_input_indices,
    ) = _try_get_metadata_from_dynamo(mod, params.keys(), len(full_args))

    # Dynamic shapes iff the first fake tensor seen carries a shape env.
    dynamic_shapes = False
    for x in full_args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
        static_input_indices=static_input_indices,
        is_export=False,
        no_tangents=False,
        cache_info=None,
        ignore_shape_env=ignore_shape_env,
        precompile_backend_id=getattr(mod, "_backend_id", None),
    )
    fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
    fake_flat_args = process_inputs(
        full_args, aot_config, fake_mode, shape_env, ignore_shape_env
    )

    def dispatch_and_compile():
        # One compilation, with compiled-autograd disabled during dispatch.
        functional_call = create_functional_call(mod, params_spec, params_len)
        with compiled_autograd._disable():
            compiled_fn, _ = create_aot_dispatcher_function(
                functional_call,
                fake_flat_args,
                aot_config,
                fake_mode,
                shape_env,
            )
        return compiled_fn

    # Only SerializableAOTDispatchCompiler outputs are cacheable; otherwise
    # compile directly.
    if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
        local = should_use_local_autograd_cache()
        remote = should_use_remote_autograd_cache()
        if local or remote:
            set_feature_use("aot_autograd_remote_cache", remote)
            compiled_fn = AOTAutogradCache.load(
                dispatch_and_compile,
                mod,
                fake_flat_args,
                aot_config,
                cudagraphs,
                boxed_forward_device_index,
                local,
                remote,
            )
        else:
            compiled_fn = dispatch_and_compile()
    else:
        compiled_fn = dispatch_and_compile()

    if isinstance(mod, torch._dynamo.utils.GmWrapper):
        # Boxed calling convention: takes a single list of args.
        def boxed_forward(runtime_args: list[Any]):
            flat_args = []
            flat_args.extend(params_flat)
            flat_args.extend(runtime_args)
            # Clear the caller's list so its tensors can be freed early.
            runtime_args.clear()
            return compiled_fn(flat_args)

        # Forward the wrapper-module surface expected by callers.
        boxed_forward.zero_grad = mod.zero_grad
        boxed_forward.named_parameters = mod.named_parameters
        boxed_forward.named_buffers = mod.named_buffers
        return boxed_forward

    # Unboxed calling convention: plain *args.
    def forward(*runtime_args: tuple[Any]):
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Forward the module surface expected by callers.
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward
|
|
|
|
|
|
|
|
|
def aot_export_module(
    mod: nn.Module,
    args,
    *,
    decompositions: Optional[dict] = None,
    # If True, we trace and return a joint forward + backward graph
    # (and require a scalar loss output, see output_loss_index).
    trace_joint: bool,
    # Required when trace_joint=True: which forward output is the scalar loss.
    output_loss_index: Optional[int] = None,
    pre_dispatch: bool = False,
    # If None, inferred inside _aot_export_function from any fake mode on the inputs.
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> tuple[torch.fx.GraphModule, GraphSignature]:
    """
    This function takes in a module, and returns:
    (1) an FX graph that can be exported
    (2) some metadata about the graph

    If `trace_joint=True` we will return a joint graph of the forward + backward.

    The traced FX graph will have the following properties compared to the original module:
    (1) Inputs and outputs to the module will be pytree-flattened
    (2) Parameters and buffers on the module will be lifted into graph inputs,
        graph_inputs = (*parameters, *buffers, *user_inputs)
    (3) The graph will be fully functionalized
    (4) Any input mutations will be converted into additional outputs in the graph,
        meaning whoever calls this graph is responsible for applying the mutations
        back to the original inputs.
    (5) If trace_joint=True, the graph will return parameter gradients in addition to user outputs.
        The graph output will look like:
        graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)

    There are also several restrictions on what modules can use this API. In particular:
    (1) If trace_joint is specified, we expect the loss function to be **fused**
        into the module forward. One of the outputs to the forward must be a scalar loss,
        which is specified with `output_loss_index`.
        All other outputs to the forward are presumed to not require gradients.
    (2) This API cannot capture optimizers (although in theory we could build an API for this).
    (3) Metadata mutations on params/buffers/inputs are banned.
    (4) Data mutations on anything that requires gradients are banned (parameters)
    (5) If an input is mutated, it is not allowed to alias any other inputs.
    (6) Parameters must not be duplicated.
    """
    if pre_dispatch and trace_joint:
        raise RuntimeError("pre_dispatch is not supported when trace_joint is True.")
    # remove_duplicate=False keeps shared params/buffers visible under every name,
    # so the graph signature stays aligned with the flattened param list below.
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {
        **dict(named_parameters),
        **dict(named_buffers),
    }
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_and_buffers_flat = tuple(params_and_buffers_flat)
    params_len = len(params_and_buffers_flat)

    kwargs = kwargs or {}

    # Turn mod into a pure function of (params_and_buffers_flat, *user_args).
    functional_call = create_functional_call(
        mod, params_spec, params_len, store_orig_mod=True
    )

    # Number of user-visible forward outputs; only known once we run the
    # forward when tracing a joint graph.
    num_fw_outs = None

    if trace_joint:
        # Wrap the functionalized forward with validation: a joint graph is only
        # built w.r.t. a single scalar loss output, so every other output must
        # not require grad.
        def fn_to_trace(*args):
            nonlocal num_fw_outs
            out = functional_call(*args)
            if output_loss_index is None:
                raise RuntimeError(
                    """\
If trace_joint=True, it is required that one of your forward outputs must be a scalar loss.
You must specify which output (by index) is the loss with output_loss_index."""
                )
            if isinstance(out, torch.Tensor):
                out = (out,)
            if not isinstance(out, (tuple, list)):
                raise RuntimeError(
                    f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}"
                )

            for i, o in enumerate(out):
                # Gradients are only computed against the loss, so any other
                # output that requires grad is a user error we surface eagerly.
                if o.requires_grad and i != output_loss_index:
                    raise RuntimeError(
                        f"""\
Found an output of the forward that requires gradients, that was not the scalar loss.
We require all outputs to the forward that are not the scalar loss to not require gradient,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients."""
                    )
            out_loss = out[output_loss_index]
            num_fw_outs = len(out)
            if not out_loss.requires_grad:
                raise RuntimeError(
                    f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients"""
                )
            if out_loss.numel() != 1:
                raise RuntimeError(
                    f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}"""
                )
            return out

        ctx = nullcontext
    else:
        # Inference export runs under no_grad so only the inference graph is
        # traced; pre_dispatch tracing keeps grad state untouched instead.
        ctx = nullcontext if pre_dispatch else torch.no_grad
        fn_to_trace = functional_call

    # Calling convention of the functionalized module:
    # params and buffers first, then the user args.
    full_args = []
    full_args.extend(params_and_buffers_flat)
    full_args.extend(args)

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            fn_to_trace,
            full_args,
            decompositions=decompositions,
            num_params_buffers=params_len,
            no_tangents=True,
            pre_dispatch=pre_dispatch,
            dynamic_shapes=dynamic_shapes,
            kwargs=kwargs,
        )
    if trace_joint:
        # The joint graph produced above takes (primals, tangents) as two
        # lists; re-trace it into a single flat calling convention that takes
        # only primals and returns (*fw_outs, *param_gradients).
        @wraps(functional_call)
        def flattened_joint(*args):
            # no_tangents=True was passed above, so every tangent slot (one per
            # forward output plus one per mutated input, per the metadata
            # counts) is filled with None.
            fake_tangents = [
                None
                for _ in range(
                    metadata.num_outputs + metadata.num_mutated_inp_runtime_indices
                )
            ]
            fw_outs, gradients = fx_g(args, fake_tangents)
            assert len(gradients) == len(args)
            output_gradients = []
            for a, grad in zip(args, gradients):
                if isinstance(a, torch.Tensor) and a.requires_grad:
                    assert (
                        grad is not None
                    ), """\
Found a parameter that did not receive a gradient.
This is most likely a bug, but if this needs to be supported please comment on this Github issue:
https://github.com/pytorch/pytorch/issues/101192
"""
                    output_gradients.append(grad)
                else:
                    # Inputs that don't require grad must not have gradients in
                    # the joint graph's output.
                    assert grad is None
            return *fw_outs, *output_gradients

        fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args)

    user_args_flat = pytree.arg_tree_leaves(*args, **kwargs)
    return fx_g, create_graph_signature(
        fx_g,
        metadata,
        in_spec,
        out_spec,
        user_args_flat=user_args_flat,
        params_and_buffers_flat=params_and_buffers_flat,
        param_names=list(named_parameters.keys()),
        buffer_names=list(named_buffers.keys()),
        trace_joint=trace_joint,
        num_user_fw_outs=num_fw_outs,
        loss_index=output_loss_index,
    )
|
|
|
|
|
|
|
|
|
def aot_export_joint_simple(
    func: Callable,
    args,
    *,
    trace_joint: bool,
    # NOTE(review): this argument is not forwarded to _aot_export_function
    # below, so it appears unused in the current implementation — confirm
    # against callers before removing.
    num_params_buffers: int = 0,
    decompositions: Optional[dict] = None,
) -> torch.fx.GraphModule:
    """
    A simplified version of export. Used by higher order operators.

    This function makes a high-level "no calling convention changes" guarantee:
    - If no inputs require grad (so we export an inference graph),
      there are *no* calling convention change between the exported graph, and "func".
    - If at least one input requires grad (so we trace out and export a joint fw-bw graph),
      Then if you were partition the graph into a separate forward and backward graph,
      The forward graph will have no calling convention changes compared to "func".

    The above also relies on some strong restrictions around which functions this API accepts:
    (1) `args` cannot contain any pytrees (they must have been pytree_flattened already)
    (2) `func` cannot mutate any inputs
    (3) The outputs of `func` cannot alias any inputs.

    Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
    """
    if trace_joint:
        ctx = nullcontext
    else:
        # Trace under no_grad so the resulting graph is an inference graph.
        ctx = torch.no_grad

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            func,
            args,
            decompositions=decompositions,
        )
    # _aot_export_function flattened (args, kwargs); no kwargs were passed here,
    # so keep only the positional-args spec.
    in_spec, _kw_in_spec = in_spec.children_specs

    # Enforce the "no calling convention changes" guarantee by rejecting
    # anything that would force one. First: no input mutations.
    if (
        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
        )

    # Second: no outputs that alias inputs (or each other / intermediates).
    if (
        len([x for x in metadata.output_info if x.output_type != OutputType.non_alias])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}"
        )

    # Third: inputs and outputs must be a single flat list/tuple of leaves —
    # no nested pytree structure is allowed.
    if in_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
        )
    if not all(child.is_leaf() for child in in_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
        )
    if out_spec.is_leaf():
        raise RuntimeError(
            f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
        )
    if not all(child.is_leaf() for child in out_spec.children_specs):
        raise RuntimeError(
            f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
        )

    # Debug-only smoke test: partition the traced graph and run the forward
    # under a fake mode to check it executes without calling-convention issues.
    # NOTE(review): `aot_config` and `fw_metadata` are not defined anywhere in
    # this function's scope, so this path would raise NameError if
    # config.debug_assert were enabled — confirm the intended source of these
    # values before relying on it.
    if config.debug_assert:
        fw_module, _bw_module = aot_config.default_partition(
            fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)
        )

        # Reuse any fake mode already attached to the args; otherwise make one
        # so the smoke run never touches real storage.
        fake_mode = detect_fake_mode(args)
        if fake_mode is None:
            fake_mode = FakeTensorMode()
        with fake_mode:
            fw_module(*args)
    return fx_g
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _aot_export_function(
    func: Callable,
    args,
    *,
    num_params_buffers: int = 0,
    decompositions: Optional[dict] = None,
    no_tangents: bool = False,
    pre_dispatch: bool = False,
    dynamic_shapes: Optional[bool] = None,
    kwargs=None,
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
    """Shared export entry point: trace ``func`` through the AOT dispatcher.

    Flattens ``(args, kwargs)`` into a single list of leaves, decides whether
    dynamic shapes should be enabled (when ``dynamic_shapes`` is None this is
    inferred from a fake mode found on the inputs, or failing that from the
    ``val`` metadata recorded on ``func._orig_mod``'s graph nodes), builds an
    export-mode AOTConfig, fakeifies the inputs, and hands everything to
    ``create_aot_dispatcher_function``.

    Returns the traced graph module, its view/mutation metadata, and the input
    and output tree specs needed to restore the original calling convention.
    """
    if not kwargs:
        kwargs = {}

    flattened_fn, out_spec_holder = create_tree_flattened_fn(func, args, kwargs)
    leaf_args, in_spec = pytree.tree_flatten((args, kwargs))

    detected_mode = None
    if dynamic_shapes is None:
        # Prefer a fake mode already attached to the flattened inputs.
        detected_mode = detect_fake_mode(leaf_args)
        if detected_mode is None and isinstance(
            getattr(func, "_orig_mod", None), torch.fx.GraphModule
        ):
            # Fall back to the example values stored on the underlying
            # GraphModule's nodes.
            node_vals = [
                node.meta["val"]
                for node in func._orig_mod.graph.nodes
                if "val" in node.meta
            ]
            detected_mode = detect_fake_mode(node_vals)
        # Shapes are dynamic iff a fake mode with a ShapeEnv was found.
        dynamic_shapes = (
            detected_mode is not None and detected_mode.shape_env is not None
        )

    # Export never compiles or partitions — all compiler hooks are None and
    # is_export=True asks the dispatcher for the raw traced graph.
    export_config = AOTConfig(
        fw_compiler=None,
        bw_compiler=None,
        inference_compiler=None,
        partition_fn=None,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=False,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=None,
        is_export=True,
        no_tangents=no_tangents,
        pre_dispatch=pre_dispatch,
    )

    if detected_mode is None:
        # Either dynamic_shapes was given explicitly or no fake mode was
        # found on the inputs: build a fresh one.
        fake_mode, shape_env = construct_fake_mode(leaf_args, export_config)
    else:
        fake_mode, shape_env = detected_mode, detected_mode.shape_env
    fake_leaf_args = process_inputs(leaf_args, export_config, fake_mode, shape_env)

    graph_module, view_mutation_meta = create_aot_dispatcher_function(
        flattened_fn,
        fake_leaf_args,
        export_config,
        fake_mode,
        shape_env,
    )
    return graph_module, view_mutation_meta, in_spec, out_spec_holder.spec
|
|
|
|
|
|
|
|
|
@contextmanager
def _detect_attribute_assignment(mod: torch.nn.Module):
    """Guard a region of code against attribute mutation on ``mod``.

    Snapshots all non-standard attributes of ``mod`` before yielding, restores
    them afterwards, and raises ValueError if attributes were added or deleted
    inside the region, or if any tensor-valued attribute was reassigned (the
    error messages direct the user to ``register_buffer`` instead).
    """
    # Attributes that every nn.Module carries; these are framework state, not
    # user state, and are excluded from the snapshot/compare.
    NN_MODULE_STD_ATTRS = [
        "_backward_hooks",
        "_backward_pre_hooks",
        "_buffers",
        "_forward_hooks",
        "_forward_hooks_always_called",
        "_forward_hooks_with_kwargs",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_is_full_backward_hook",
        "_load_state_dict_post_hooks",
        "_load_state_dict_pre_hooks",
        "_modules",
        "_non_persistent_buffers_set",
        "_parameters",
        "_state_dict_hooks",
        "_state_dict_pre_hooks",
        "training",
    ]
    # Extra attributes present on lazy modules.
    NN_MODULE_LAZY_STD_ATTRS = [
        "_initialize_hook",
        "_load_hook",
    ]
    STD_ATTRS = {
        *NN_MODULE_STD_ATTRS,
        *NN_MODULE_LAZY_STD_ATTRS,
    }

    def _get_attributes(mod):
        # Everything in mod.__dict__ that is not a standard nn.Module attribute.
        return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}

    # Identity tree_map rebuilds the container structure of each attribute
    # (shallow copy) while keeping the leaf objects themselves, so later
    # leaf-by-leaf identity comparison can detect reassignment. Subclasses
    # that lose information under pytree flattening are kept as opaque leaves.
    snapshot = pytree.tree_map(
        lambda x: x,
        _get_attributes(mod),
        is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
    )
    try:
        yield
    finally:
        # Compare current attribute state against the snapshot and record every
        # leaf whose identity changed and whose previous value was a tensor.
        assigned_tensor_attributes = []

        def _collect_assigned_tensor_attributes(kp, v, _v):
            # v comes from `snapshot`, _v from `new_attrs` (argument order
            # follows the two trees passed to tree_map_with_path below).
            if _v is not v:
                attr, *rest = kp
                # NOTE(review): the tensor check is on the *old* value; a
                # non-tensor attribute reassigned to a tensor would not be
                # reported here — confirm this is intended.
                if isinstance(v, torch.Tensor):
                    assigned_tensor_attributes.append(
                        f"self.{attr.key}{pytree.keystr(rest)}"
                    )
            return v

        new_attrs = _get_attributes(mod)
        if len(new_attrs) != len(snapshot):
            added_attrs = new_attrs.keys() - snapshot.keys()
            deleted_attrs = snapshot.keys() - new_attrs.keys()

            if len(added_attrs) > 0:
                raise ValueError(
                    f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
                    f"Such attributes must be registered as buffers using the `register_buffer` "
                    f"API and must be initialized at model.__init__ "
                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
                )

            if len(deleted_attrs) > 0:
                raise ValueError(
                    f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
                    f"Such attributes must be registered as buffers using the `register_buffer` "
                    f"API and must be initialized at model.__init__ "
                    f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
                )

        # Walks both trees in lockstep; populates assigned_tensor_attributes.
        pytree.tree_map_with_path(
            _collect_assigned_tensor_attributes, snapshot, new_attrs
        )

        # Roll back every attribute (including non-tensor ones) to its
        # pre-region value before reporting.
        mod.__dict__.update(snapshot)

        if assigned_tensor_attributes:
            if len(assigned_tensor_attributes) > 1:
                noun, verb = "attributes", "were"
            else:
                noun, verb = "attribute", "was"
            raise ValueError(
                f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
                "Such attributes must be registered as buffers using the `register_buffer` API "
                "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
            )
|
|
|
|
|
|
|
|
|
# Module-level aliases for the primary entry points defined earlier in this
# file — presumably kept for backward compatibility with older callers;
# confirm before removing.
compiled_function = aot_function
compiled_module = aot_module
|
|
|
|