import collections
import contextlib
import copy
import dataclasses
import functools
import io
import itertools
import json
import logging
import os
import os.path
import pickle
import pstats
import shutil
import traceback
from collections.abc import Iterator, Sequence
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
from unittest.mock import patch

import torch
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch._inductor import utils
from torch._logging import getArtifactLogger
from torch._logging._internal import trace_structured
from torch._utils_internal import signpost_event
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.types import FileLike
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map

from . import config, ir
from .ir import ExternKernel
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V


if TYPE_CHECKING:
    from .select_algorithm import ChoiceCaller


log = logging.getLogger(__name__)

# Globals used by record_and_log_graph_execution_order() to capture the
# runtime order of compiled graphs.
GRAPH_EXECUTION_ORDER: Optional[list[dict[str, object]]] = None
RECORD_GRAPH_EXECUTION: bool = False
GRAPH_COMPILE_IDS: Optional[dict[int, Optional[str]]] = None

ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion")
ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion")
SchedulerNodeList = list[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]


@functools.cache
def has_dot() -> bool:
    return shutil.which("dot") is not None


def draw_buffers(
    nodes: list[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
    """
    Draw a graph in fname.svg.
    """
    if not has_dot():
        log.warning("draw_buffers() requires the `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            if isinstance(group[1], int):
                group = (group[1],)
            else:
                group = group[1]

        # gather metadata
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
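

# Example usage (sketch): given scheduler nodes collected from an Inductor
# Scheduler instance (how they are obtained is context-dependent), render the
# buffer graph to /tmp/buffers.svg. Requires graphviz's `dot` on PATH.
#
#   draw_buffers(snodes, print_graph=True, fname="/tmp/buffers")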


def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates an FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name: str) -> Callable[..., int]:
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create a fake call_function node for each scheduler node
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)

        def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(
                isinstance(user.node, OutputNode)
                for buf in snode.get_outputs()
                for user in buf.users
            )

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes based on read dependencies
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = node_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # skip self-loops
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
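

# Example (sketch): the returned graph is for visualization only. Each
# scheduler node becomes a fake `call_function` whose target always returns
# 0, with scheduling details attached under node.meta["fusion_meta"]:
#
#   g = create_fx_from_snodes(snodes)  # snodes from a Scheduler (assumed)
#   for n in g.nodes:
#       print(n.name, n.meta.get("fusion_meta"))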


def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    if nodes is None:
        return
    for node in nodes:
        # recurse into fused nodes; the parent buffer name is propagated
        buf_name = node.get_name()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                buf_name if parent_buf_name is None else parent_buf_name,
            )
            continue
        else:
            assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            node_name = origin.name

            # when buf1 and buf2 both have origin=node1,
            # node1 is drawn according to buf1
            if node_name not in node_name_to_buf_name:
                node_name_to_buf_name[node_name] = (
                    buf_name if parent_buf_name is None else parent_buf_name
                )


def get_node_name_to_buf_meta(
    node_name_to_buf_name: dict[str, str],
) -> dict[str, BufMeta]:
    buf_name_to_n_node = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        if buf_name not in buf_name_to_n_node:
            buf_name_to_n_node[buf_name] = OrderedSet([node_name])
        else:
            buf_name_to_n_node[buf_name].add(node_name)

    node_name_to_buf_meta = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        n_node = len(buf_name_to_n_node[buf_name])
        node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node)
    return node_name_to_buf_meta


def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotate the nodes of the original FX graph with buffer metadata derived
    from the SchedulerNode objects they originate from.
    """
    node_name_to_buf_name: dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    if node_name_to_buf_name is None:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)


@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable per-graph debug logging to a file by turning on the partitioner
    # debug flag and attaching a DEBUG-level file handler.
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()
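

# Example (sketch): with TORCH_COMPILE_DEBUG=1 in the environment, wrapping an
# AOT compilation in this context writes per-graph debug logs:
#
#   TORCH_COMPILE_DEBUG=1 python train.py
#
# produces aot_<graph_name>_debug.log files under <debug_dir>/torchinductor.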


# Provenance-tracking globals mapping between pre-grad graph nodes, post-grad
# graph nodes, and generated kernels.
_inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {}
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
_pre_grad_graph_id: Optional[int] = None
_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
_inductor_kernel_stack_trace: dict[str, list[str]] = {}
_inductor_kernel_provenance_debug_handle: int = 0


def reset_inductor_kernel_provenance_debug_handle() -> None:
    global _inductor_kernel_provenance_debug_handle
    _inductor_kernel_provenance_debug_handle = 0


@contextlib.contextmanager
def reset_provenance_globals() -> Iterator[None]:
    """Context manager that resets provenance tracking globals upon entering
    and restores their original values when exiting."""
    global _pre_grad_graph_id
    global _inductor_post_to_pre_grad_nodes
    global _inductor_triton_kernel_to_post_grad_node_info
    global _inductor_pre_grad_node_stack_trace
    global _inductor_kernel_stack_trace

    # Store original values
    original_pre_grad_graph_id = _pre_grad_graph_id
    original_post_to_pre_grad_nodes = _inductor_post_to_pre_grad_nodes.copy()
    original_triton_kernel_to_post_grad_node_info = (
        _inductor_triton_kernel_to_post_grad_node_info.copy()
    )
    original_inductor_pre_grad_node_stack_trace = (
        _inductor_pre_grad_node_stack_trace.copy()
    )
    original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy()

    # Reset to fresh values
    _pre_grad_graph_id = -1
    _inductor_post_to_pre_grad_nodes = {}
    _inductor_triton_kernel_to_post_grad_node_info = {}
    _inductor_pre_grad_node_stack_trace = {}
    _inductor_kernel_stack_trace = {}

    try:
        yield
    finally:
        # Restore original values
        _pre_grad_graph_id = original_pre_grad_graph_id
        _inductor_post_to_pre_grad_nodes = original_post_to_pre_grad_nodes
        _inductor_triton_kernel_to_post_grad_node_info = (
            original_triton_kernel_to_post_grad_node_info
        )
        _inductor_kernel_stack_trace = original_inductor_kernel_stack_trace
        _inductor_pre_grad_node_stack_trace = (
            original_inductor_pre_grad_node_stack_trace
        )
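

# Example (sketch): useful in tests that compile several graphs in one
# process and must not leak provenance state between them:
#
#   with reset_provenance_globals():
#       torch.compile(fn)(x)  # provenance globals start fresh here
#   # prior provenance state is restored on exit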


class DebugContext:
    _counter = itertools.count()

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
        return None

    def __init__(self) -> None:
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str) -> None:
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        from filelock import FileLock

        try:
            with FileLock(f"{new_path}.lock"):
                if os.path.exists(new_path):
                    shutil.rmtree(new_path)
                shutil.copytree(self._path, new_path)
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )

    def fopen(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> IO[Any]:
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> Iterator[IO[Any]]:
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str) -> str:
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self) -> None:
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self) -> None:
        if config.debug:
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level: Any) -> None:
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)

    def _setup_log_capture(
        self,
        filename: str,
        level: int,
    ) -> None:
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self) -> None:
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
                return None
        else:

            def ignored(*args: Any, **kwargs: Any) -> None:
                pass

            return ignored
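

# Example (sketch): DebugContext is normally entered by Inductor's compile
# pipeline; with tracing enabled it creates a fresh
# <debug_dir>/torchinductor/<graph_name>.<n> directory and routes V.debug
# hooks (fx_graph, ir_pre_fusion, output_code, ...) into it:
#
#   with config.patch({"trace.enabled": True}):
#       with DebugContext():
#           ...  # compilation steps that call V.debug.* land artifacts here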


class DebugFormatter:
    def __init__(self, handler: DebugContext) -> None:
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: list[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_runnable.py") as fd:
            save_dir = None
            if torch._inductor.config.trace.save_real_tensors:
                inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs)
                save_dir = os.path.dirname(fd.name)

            # Disable tracing (and real-tensor saving) while writing the
            # repro so that generating it does not re-enter this debug path.
            stable_hash = torch._inductor.config.trace.save_real_tensors
            with torch._inductor.config.patch(
                {"trace.enabled": False, "trace.save_real_tensors": False}
            ):
                save_graph_repro(
                    fd,
                    gm,
                    inputs,
                    "inductor",
                    save_dir=save_dir,
                    stable_hash=stable_hash,
                )

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: list[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        with self.fopen("ir_pre_fusion.txt") as fd:
            fd.write(self._write_ir(nodes))

    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        with self.fopen("ir_post_fusion.txt") as fd:
            fd.write(self._write_ir(nodes))

    @staticmethod
    def _write_ir(nodes: SchedulerNodeList) -> str:
        buf = io.StringIO()
        for node in nodes:
            buf.write(node.debug_str())
            buf.write("\n\n\n")
        return buf.getvalue()

    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename: str, extension: str = "py") -> None:
        shutil.copy(filename, self.filename(f"output_code.{extension}"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: list[ir.IRNode],
        timings: dict["ChoiceCaller", float],
        elapse: float,
        precompile_elapse: float,
        prescreening_elapse: Optional[float],
    ) -> None:
        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode) -> dict[str, str]:
            if hasattr(node, "name"):
                node_name = node.name
            else:
                node_name = ""
            node_info = {
                "name": node_name,
                "type": type(node).__name__,
            }
            try:
                layout = node.get_output_spec()
                if isinstance(layout, FixedLayout):
                    offset = 0
                    try:
                        offset = int(layout.offset)
                    except Exception:
                        try:
                            offset = V.graph.sizevars.size_hint(
                                layout.offset, fallback=0
                            )
                        except Exception:
                            pass
                    static_layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=[*V.graph.sizevars.size_hints(layout.size)],
                        stride=[*V.graph.sizevars.size_hints(layout.stride)],
                        offset=offset,
                    )
                    node_info["layout"] = str(static_layout)
                else:
                    node_info["layout"] = str(layout)
            except Exception:
                pass
            try:
                node_info["dtype"] = str(node.get_dtype())
            except Exception:
                pass
            try:
                node_info["device"] = str(node.get_device())
            except Exception:
                pass
            try:
                node_info["stride"] = str(
                    V.graph.sizevars.size_hints(node.get_stride())
                )
            except Exception:
                pass
            try:
                node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
            except Exception:
                pass
            try:
                node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
            except Exception:
                pass
            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)
            return node_info

        general_properties = {
            "op_name": name,
            "cuda_device_name": torch.cuda.get_device_name(),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
            "precompile_time": precompile_elapse,
            "prescreening_time": prescreening_elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = time
                json.dump(info_dict, fd)
                fd.write("\n")
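

# The resulting autotuning_result_json_list.txt is JSON Lines: one object per
# benchmarked choice, combining the choice's own info_dict() with the general
# properties above, e.g. (illustrative field values):
#
#   {"op_name": "mm", "autotuning_time": 1.2, "benchmark_result": 0.013, ...}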


def log_ir_pre_fusion(nodes: SchedulerNodeList) -> None:
    if ir_pre_fusion_log.isEnabledFor(logging.INFO):
        ir_pre_fusion_log.info("BEFORE FUSION\n%s", DebugFormatter._write_ir(nodes))

    V.debug.ir_pre_fusion(nodes)


def log_ir_post_fusion(nodes: SchedulerNodeList) -> None:
    if ir_post_fusion_log.isEnabledFor(logging.INFO):
        ir_post_fusion_log.info("AFTER FUSION\n%s", DebugFormatter._write_ir(nodes))

    V.debug.ir_post_fusion(nodes)


def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None:
    try:
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "inductor_collective_schedule",
                "encoding": "json",
            },
            payload_fn=lambda: schedule,
        )
    except Exception:
        log.debug(
            "Failed to log inductor_collective_schedule via structured logging",
            exc_info=True,
        )


def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
    schedule = [
        getattr(op, "python_kernel_name", None)
        for node in nodes
        if isinstance(op := getattr(node, "node", None), ir._CollectiveKernel)
    ]

    # only emit the artifact when there is at least one collective kernel
    if schedule:
        _dump_collective_schedule(schedule)


def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
    """Log per-op runtime estimates and output tensor metadata for TLParse."""

    try:
        to_size_hints = V.graph.sizevars.size_hints

        def to_list(x: Optional[Sequence[Any]]) -> list[Any]:
            return list(to_size_hints(x)) if x is not None else []

        def dtype_to_str(dtype: Any) -> Optional[str]:
            if dtype is None:
                return None
            return str(dtype).removeprefix("torch.")

        ops: list[dict[str, Any]] = []
        for s, runtime_ns in node_runtimes:
            name = getattr(s.node, "python_kernel_name", s.get_name())
            op_type = "collective" if utils.is_collective(s.node) else "compute"

            # Collect output tensor metadata; best-effort, since some IR
            # nodes may not expose sizes, strides, or dtypes.
            outputs: list[dict[str, Any]] = []
            try:
                for buf in s.get_outputs():
                    irnode = buf.node
                    shape = irnode.maybe_get_size()
                    stride = (
                        irnode.get_stride()
                        if isinstance(irnode.layout, ir.Layout)
                        else None
                    )
                    dtype = irnode.maybe_get_dtype()
                    outputs.append(
                        {
                            "shape": to_list(shape),
                            "stride": to_list(stride),
                            "dtype": dtype_to_str(dtype),
                        }
                    )
            except Exception:
                pass

            ops.append(
                {
                    "name": name,
                    "type": op_type,
                    "estimated_runtime_ns": runtime_ns,
                    "outputs": outputs,
                }
            )

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "inductor_runtime_and_tensor_meta",
                "encoding": "json",
            },
            payload_fn=lambda: {"ops": ops},
        )
    except Exception:
        log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)


def log_graph_execution() -> None:
    """Emit a structured artifact with the graph execution order."""
    if not GRAPH_EXECUTION_ORDER:
        return
    try:
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "graph_execution",
                "encoding": "json",
            },
            payload_fn=lambda: {"graph_execution_order": GRAPH_EXECUTION_ORDER},
        )
    except Exception:
        log.debug("Failed to log graph_execution", exc_info=True)


@contextlib.contextmanager
def record_and_log_graph_execution_order() -> Iterator[None]:
    """Record graph execution order and log it once on exit."""
    global RECORD_GRAPH_EXECUTION, GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS
    GRAPH_EXECUTION_ORDER = []
    GRAPH_COMPILE_IDS = {}
    RECORD_GRAPH_EXECUTION = True
    try:
        yield
    finally:
        log_graph_execution()
        RECORD_GRAPH_EXECUTION = False
        GRAPH_EXECUTION_ORDER = None
        GRAPH_COMPILE_IDS = None
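

# Example (sketch): capture the order in which compiled graphs execute and
# emit it as one "graph_execution" structured-trace artifact on exit:
#
#   with record_and_log_graph_execution_order():
#       compiled_a(x)
#       compiled_b(y)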


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def create_mapping_pre_post_grad_nodes(
    pre_grad_graph_id: Optional[int],
    post_to_pre_grad_nodes_json: dict[str, Any],
) -> dict[str, dict[str, list[str]]]:
    """
    Create bidirectional mappings between pre_grad graph nodes
    and post_grad graph code nodes, and vice versa.
    """

    empty_return: dict[str, dict[str, list[str]]] = {
        "preToPost": {},
        "postToPre": {},
    }

    if not isinstance(post_to_pre_grad_nodes_json, dict):
        log.error(
            "Provenance tracking error: post_to_pre_grad_nodes_json is not a dict"
        )
        return empty_return

    if not isinstance(pre_grad_graph_id, int):
        # the mapping is only meaningful when a pre_grad graph id was
        # recorded
        return empty_return

    pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
    post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)

    try:

        def check_format(node: dict[str, Any]) -> bool:
            if not isinstance(node, dict):
                log.error(
                    "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json is not a dict"
                )
                return False
            if "graph_id" not in node or "name" not in node or "from_node" not in node:
                log.error(
                    "Provenance tracking error: node provenance in post_to_pre_grad_nodes_json has wrong format"
                )
                return False
            return True

        for outer_key, node_array in post_to_pre_grad_nodes_json.items():
            if not isinstance(node_array, list):
                log.error(
                    "Provenance tracking error: post_to_pre_grad_nodes_json value is not a list"
                )
                return empty_return
            for node in node_array:
                if not check_format(node):
                    return empty_return

                if node.get("graph_id") == pre_grad_graph_id:
                    pre_to_post[node["name"]].add(outer_key)
                    post_to_pre[outer_key].add(node["name"])

                # walk the from_node chains to find transitive pre_grad
                # origins of this post_grad node
                stack = [(n, outer_key) for n in node.get("from_node", [])]
                while stack:
                    current_node, parent_key = stack.pop()
                    if not check_format(current_node):
                        return empty_return
                    if current_node.get("graph_id") == pre_grad_graph_id:
                        pre_to_post[current_node["name"]].add(parent_key)
                        post_to_pre[parent_key].add(current_node["name"])
                    stack.extend(
                        (n, parent_key) for n in current_node.get("from_node", [])
                    )

        def convert_sets_to_lists(d: dict[str, Any]) -> None:
            for key in d:
                d[key] = list(d[key])
            d = dict(d)

        # convert to plain lists so the result is JSON-serializable
        convert_sets_to_lists(pre_to_post)
        convert_sets_to_lists(post_to_pre)
        return {
            "preToPost": pre_to_post,
            "postToPre": post_to_pre,
        }
    except Exception as e:
        # log and proceed; provenance tracking must never break compilation
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "create_mapping_pre_post_grad_nodes",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
        log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
        return empty_return
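

# Example (sketch): expected input/output shapes, with illustrative names:
#
#   post_to_pre = {"add_1": [{"graph_id": 0, "name": "x", "from_node": []}]}
#   create_mapping_pre_post_grad_nodes(0, post_to_pre)
#   # -> {"preToPost": {"x": ["add_1"]}, "postToPre": {"add_1": ["x"]}}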


def create_node_mapping_kernel_to_post_grad(
    triton_kernel_to_post_grad_json: dict[str, Any],
) -> dict[str, dict[str, Any]]:
    """Create bidirectional mappings between triton kernel name and post_grad
    graph code nodes, and vice versa.
    """

    empty_return: dict[str, dict[str, Any]] = {
        "cppCodeToPost": {},
        "postToCppCode": {},
    }

    if not isinstance(triton_kernel_to_post_grad_json, dict):
        log.error(
            "Provenance tracking error: triton_kernel_to_post_grad_json is not a dict"
        )
        return empty_return

    post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)

    try:
        for outer_key, node_array in triton_kernel_to_post_grad_json.items():
            if not isinstance(node_array, list):
                log.error(
                    "Provenance tracking error: triton_kernel_to_post_grad_json value is not a list"
                )
                return empty_return
            for curr_node in node_array:
                post_to_cpp_code[curr_node].add(outer_key)

        def convert_sets_to_lists(d: dict[str, Any]) -> None:
            for key in d:
                d[key] = list(d[key])
            d = dict(d)

        # convert to plain lists so the result is JSON-serializable
        convert_sets_to_lists(post_to_cpp_code)
        return {
            "cppCodeToPost": triton_kernel_to_post_grad_json,
            "postToCppCode": post_to_cpp_code,
        }
    except Exception as e:
        # log and proceed; provenance tracking must never break compilation
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "create_mapping_kernel_to_post_grad",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        log.error(
            "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
        )
        return empty_return


def dump_inductor_provenance_info() -> dict[str, Any]:
    try:
        global _pre_grad_graph_id
        global _inductor_post_to_pre_grad_nodes
        global _inductor_triton_kernel_to_post_grad_node_info
        node_mapping: dict[str, Any] = {}
        if _pre_grad_graph_id:
            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
                _inductor_triton_kernel_to_post_grad_node_info
            )
            node_mapping = {
                **_inductor_post_to_pre_grad_nodes,
                **node_mapping_kernel,
            }
        if config.trace.enabled:
            with V.debug.fopen(
                "inductor_provenance_tracking_node_mappings.json", "w"
            ) as fd:
                json.dump(node_mapping, fd)

        node_mapping["version"] = 2.0
        return node_mapping
    except Exception as e:
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "dump_inductor_provenance_info",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}


def create_kernel_information_json() -> dict[str, dict[str, list[str]]]:
    """Create kernel information JSON"""
    try:
        global _inductor_post_to_pre_grad_nodes
        global _inductor_kernel_stack_trace
        global _inductor_triton_kernel_to_post_grad_node_info

        post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {})
        all_kernels = OrderedSet(_inductor_kernel_stack_trace.keys()) | OrderedSet(
            _inductor_triton_kernel_to_post_grad_node_info.keys()
        )

        result = {}
        for kernel_name in all_kernels:
            post_grad_nodes = _inductor_triton_kernel_to_post_grad_node_info.get(
                kernel_name, []
            )

            pre_grad_nodes: OrderedSet[str] = OrderedSet()
            for post_node in post_grad_nodes:
                pre_grad_nodes.update(post_to_pre.get(post_node, []))

            result[kernel_name] = {
                "stack_traces": _inductor_kernel_stack_trace.get(kernel_name, []),
                "post_grad_nodes": post_grad_nodes,
                "pre_grad_nodes": list(pre_grad_nodes),
            }

        return result
    except Exception as e:
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "create_kernel_information_json",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}
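

# Example (sketch): each kernel name (suffixed with its debug handle by
# set_kernel_post_grad_provenance_tracing below) maps to its provenance
# record; field values are illustrative:
#
#   {"triton_poi_fused_add_0:1": {"stack_traces": [...],
#        "post_grad_nodes": ["add"], "pre_grad_nodes": ["add_1"]}}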


def set_kernel_post_grad_provenance_tracing(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    kernel_name: str,
    is_extern: bool = False,
) -> Optional[int]:
    """
    Set the mapping between `kernel_name` and the post_grad nodes in `node_schedule`.

    Returns a unique int debug handle for each call to this function.
    """

    try:
        from .codegen.simd_kernel_features import DisableReduction, EnableReduction

        global _inductor_triton_kernel_to_post_grad_node_info
        global _inductor_kernel_stack_trace
        global _inductor_kernel_provenance_debug_handle

        _inductor_kernel_provenance_debug_handle += 1
        stack_traces: list[str] = []
        kernel_name = f"{kernel_name}:{_inductor_kernel_provenance_debug_handle}"
        if is_extern:
            assert isinstance(node_schedule, ExternKernel)
            curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
                kernel_name, []
            )

            # 'origin_node' is more precise than 'origins' when available,
            # so prefer it for extern kernels
            if node_schedule.origin_node:
                origin_node_name = node_schedule.origin_node.name
                if origin_node_name not in curr_node_info:
                    curr_node_info.append(origin_node_name)
            else:
                curr_node_info.extend(
                    origin.name
                    for origin in node_schedule.origins
                    if origin.name not in curr_node_info
                )
            stack_traces = list(node_schedule.get_stack_traces())
        else:
            assert isinstance(node_schedule, list)
            stack_traces_set: OrderedSet[str] = OrderedSet()
            for snode in node_schedule:
                if snode not in (EnableReduction, DisableReduction):
                    if snode.node is not None:
                        curr_node_info = (
                            _inductor_triton_kernel_to_post_grad_node_info.setdefault(
                                kernel_name, []
                            )
                        )
                        stack_traces_set.update(snode.node.get_stack_traces())
                        curr_node_info.extend(
                            origin.name
                            for origin in snode.node.origins
                            if origin.name not in curr_node_info
                        )
            stack_traces = list(stack_traces_set)
        _inductor_kernel_stack_trace.setdefault(kernel_name, []).extend(stack_traces)
        return _inductor_kernel_provenance_debug_handle
    except Exception as e:
        # log and proceed; provenance tracking must never break compilation
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "set_kernel_post_grad_provenance_tracing",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        return None


def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    This function is used to save arguments for a compile_fx_inner function call
    to the file system. Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def handle_tensor(x: Any) -> Any:
        """
        Pickling a FakeTensor results in an error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensors to metadata. This may also make pickling faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
"""
        # print rather than log so the multi-line snippet is emitted verbatim
        # and is easy to copy/paste
        print(message)


def load_args_and_run_compile_fx_inner(path: str) -> Any:
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x: Any) -> Any:
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)


def aot_inductor_minifier_wrapper(
    func: Callable[..., str],
    exported_program: torch.export.ExportedProgram,
    *,
    inductor_configs: dict[str, Any],
    package_path: Optional[FileLike] = None,
) -> str:
    from torch._dynamo.debug_utils import AccuracyError
    from torch._dynamo.repro.aoti import dump_to_minify
    from torch._inductor import config
    from torch._inductor.compile_fx import _aoti_flatten_inputs

    use_minifier = config.aot_inductor.dump_aoti_minifier

    gm = exported_program.module(check_guards=False)
    assert isinstance(gm, torch.fx.GraphModule)

    args, kwargs = exported_program.example_inputs

    try:
        if use_minifier and config.aot_inductor.repro_level == 3:
            # always dump the original module up front
            dump_to_minify(
                exported_program,
                "aot_inductor",
                options=inductor_configs,
            )
        if use_minifier and config.aot_inductor.repro_level == 4:
            # run an accuracy check on a deep copy with flattened inputs so
            # the original exported_program is left untouched
            gm_copy = copy.deepcopy(gm)
            example_inputs_copy = copy.deepcopy(exported_program.example_inputs)
            config_copy = copy.deepcopy(inductor_configs)
            flat_example_inputs, config_copy = _aoti_flatten_inputs(
                gm_copy,
                example_inputs_copy[0],
                example_inputs_copy[1],
                options=config_copy,
            )
            tuple_inputs = tuple(flat_example_inputs)
            flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False)
            func(
                flattened_ep.module(check_guards=False),
                tuple_inputs,
                inductor_configs=config_copy,
                package_path=package_path,
                load_and_run=True,
                check_accuracy="accuracy",
            )

        return func(
            gm,
            args,
            kwargs,
            inductor_configs=inductor_configs,
            package_path=package_path,
            load_and_run=use_minifier,
        )
    except AccuracyError as e:
        dump_to_minify(
            exported_program,
            "aot_inductor_accuracy",
            command="minify",
            options=inductor_configs,
        )
        log.warning("Accuracy failed")
        raise e
    except Exception as e:
        if use_minifier:
            command = "minify"
            # repro level 1 dumps a runnable repro without minifying
            if config.aot_inductor.repro_level == 1:
                command = "run"

            dump_to_minify(
                exported_program,
                "aot_inductor",
                command=command,
                options=inductor_configs,
            )
        raise e