|
|
""" |
|
|
Debug utilities for TorchDynamo compilation and execution. |
|
|
|
|
|
This module provides various debugging tools and utilities for TorchDynamo, including: |
|
|
|
|
|
- Minification support for reducing test cases while preserving bugs |
|
|
- Input/output handling via InputReader and InputWriter for reproducible testing |
|
|
- Accuracy checking between original and compiled models |
|
|
- Neural network module string conversion via NNModuleToString |
|
|
- Profiling tools and system information collection |
|
|
- Buck build system integration for Meta-internal testing |
|
|
|
|
|
Key classes: |
|
|
- InputReader/InputWriter: Handle serialization of model inputs/outputs |
|
|
- NNModuleToString: Converts nn.Modules to string representations |
|
|
- BuckTargetWriter: Manages Buck build system integration |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import atexit |
|
|
import copy |
|
|
import cProfile |
|
|
import functools |
|
|
import getpass |
|
|
import inspect |
|
|
import itertools |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
import subprocess |
|
|
import sys |
|
|
import tempfile |
|
|
import textwrap |
|
|
from collections import Counter |
|
|
from importlib import import_module |
|
|
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar |
|
|
|
|
|
import torch |
|
|
import torch._prims_common as utils |
|
|
import torch._subclasses.meta_utils |
|
|
from torch import Tensor |
|
|
from torch._dynamo.testing import rand_strided |
|
|
from torch._inductor.cpp_builder import normalize_path_separator |
|
|
from torch._prims_common import is_float_dtype |
|
|
from torch.multiprocessing.reductions import StorageWeakRef |
|
|
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter |
|
|
|
|
|
from . import config |
|
|
from .utils import clone_inputs, get_debug_dir |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Sequence |
|
|
|
|
|
from torch.hub import tqdm |
|
|
from torch.storage import UntypedStorage |
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
inductor_config = import_module("torch._inductor.config") |
|
|
use_buck = inductor_config.is_fbcode() |
|
|
|
|
|
if use_buck: |
|
|
import libfb.py.build_info |
|
|
|
|
|
|
|
|
extra_deps = [] |
|
|
extra_imports = "" |
|
|
cur_target = "" |
|
|
if use_buck: |
|
|
extra_deps = [ |
|
|
"//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", |
|
|
"//caffe2/torch/fb/sparsenn:sparsenn_operators", |
|
|
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu", |
|
|
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops", |
|
|
] |
|
|
cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") |
|
|
extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) |
|
|
|
|
|
|
|
|
BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] |
|
|
|
|
|
|
|
|
class BuckTargetWriter: |
|
|
def __init__(self, filename: str) -> None: |
|
|
self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) |
|
|
self.target = self.py_file.replace(".py", "") |
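
        # Derive the Python main_module path: keep everything from "fbcode."
        # onward, then strip the "fbcode." prefix itself (7 characters).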
|
|
|
|
|
|
|
|
self.path = f"{self.subdir.replace('/', '.')}.{self.target}" |
|
|
self.path = self.path[self.path.find("fbcode.") :] |
|
|
self.path = self.path[7:] |
|
|
|
|
|
|
|
|
tmp = self.subdir |
|
|
tmp = tmp[tmp.find("fbcode/") :][7:] |
|
|
self.cmd_line_path = f"//{tmp}:{self.target}" |
|
|
|
|
|
def build(self) -> str: |
|
|
extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) |
|
|
return textwrap.dedent( |
|
|
f""" |
|
|
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") |
|
|
|
|
|
python_binary( |
|
|
name="{self.target}", |
|
|
srcs = ["{self.py_file}"], |
|
|
compile = False, |
|
|
deps = [ |
|
|
"//caffe2:torch", |
|
|
"//caffe2:libtorch", |
|
|
"//caffe2/functorch:functorch", |
|
|
"//triton:triton", |
|
|
"{cur_target}", |
|
|
], |
|
|
cpp_deps = [ |
|
|
{extra_cpp_deps} |
|
|
], |
|
|
main_module = "{self.path}", |
|
|
par_style = "xar", |
|
|
) |
|
|
""" |
|
|
) |
|
|
|
|
|
def write(self, print_msg: bool = True) -> list[str]: |
|
|
target_file = os.path.join(self.subdir, "TARGETS") |
|
|
with open(target_file, "w") as fd: |
|
|
fd.write(self.build()) |
|
|
|
|
|
cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] |
|
|
if print_msg: |
|
|
log.warning( |
|
|
"Found an example that reproduces the error. Run this cmd to repro - %s", |
|
|
" ".join(cmd_split), |
|
|
) |
|
|
return cmd_split |
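

# A minimal usage sketch (the repro path below is hypothetical): the writer
# emits a TARGETS file next to the generated script and returns the buck2
# command that runs it.
#
#   writer = BuckTargetWriter("/data/users/me/fbcode/scripts/repro.py")
#   cmd = writer.write(print_msg=False)  # ["buck2", "run", "@mode/dev-nosan", "//scripts:repro"]
#   subprocess.run(cmd)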
|
|
|
|
|
|
|
|
def minifier_dir() -> str: |
|
|
    debug_dir = get_debug_dir()
    # os.path.join() never returns None, so guard on get_debug_dir() itself
    # and fall back to a per-user temp directory when no debug dir is set.
    if debug_dir is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    else:
        path = os.path.join(debug_dir, "minifier")
|
|
if not os.path.exists(path): |
|
|
os.makedirs(path, exist_ok=True) |
|
|
return path |
|
|
|
|
|
|
|
|
MAX_CONSTANT_NUMEL_INLINE = 4 |
|
|
|
|
|
|
|
|
class NNModuleToString: |
|
|
safe_reprs = [ |
|
|
torch.nn.Linear, |
|
|
torch.nn.Conv1d, |
|
|
torch.nn.Conv2d, |
|
|
torch.nn.Conv3d, |
|
|
torch.nn.BatchNorm1d, |
|
|
torch.nn.BatchNorm2d, |
|
|
torch.nn.BatchNorm3d, |
|
|
torch.nn.LayerNorm, |
|
|
torch.nn.Dropout, |
|
|
torch.nn.Softmax, |
|
|
torch.nn.ReLU, |
|
|
torch.nn.GELU, |
|
|
torch.nn.Identity, |
|
|
torch.nn.MaxPool2d, |
|
|
torch.nn.Embedding, |
|
|
torch.nn.Tanh, |
|
|
torch.nn.ConvTranspose1d, |
|
|
torch.nn.GLU, |
|
|
torch.nn.LSTM, |
|
|
torch.nn.Flatten, |
|
|
torch.nn.AdaptiveAvgPool2d, |
|
|
] |
|
|
|
|
|
@staticmethod |
|
|
def can_convert_to_string(gm: torch.fx.GraphModule) -> bool: |
|
|
cant_convert = set() |
|
|
for _, module in gm.named_children(): |
|
|
if type(module) not in NNModuleToString.safe_reprs: |
|
|
cant_convert.add(module) |
|
|
|
|
|
if len(cant_convert) > 0: |
|
|
log.warning("We have not tested reprs of some modules - %s", cant_convert) |
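
        # NB: we optimistically return True even when some child modules have
        # untested reprs; the warning above only flags them for inspection.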
|
|
|
|
|
return True |
|
|
|
|
|
@staticmethod |
|
|
def convert(gm: torch.fx.GraphModule) -> str: |
|
|
from torch.nn.modules.module import _addindent |
|
|
|
|
|
tab = " " * 4 |
|
|
|
|
|
model_str = textwrap.dedent( |
|
|
""" |
|
|
from torch.nn import * |
|
|
class Repro(torch.nn.Module): |
|
|
def __init__(self) -> None: |
|
|
super().__init__() |
|
|
""" |
|
|
) |
|
|
|
|
|
for module_name, module in gm.named_children(): |
|
|
module_str = f"{module.__repr__()}" |
|
|
|
|
|
|
|
|
example_param = next(module.parameters(), None) |
|
|
if example_param is not None and example_param.is_cuda: |
|
|
module_str = f"{module_str}.cuda()" |
|
|
model_str += f"{tab * 2}self.{module_name} = {module_str}\n" |
|
|
|
|
|
for buffer_name, buffer in gm._buffers.items(): |
|
|
if buffer is None: |
|
|
continue |
|
|
|
|
|
if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE: |
|
|
from torch._tensor_str import PRINT_OPTS |
|
|
|
|
|
assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE |
|
|
tensor_str = repr(buffer) |
|
|
elif torch.is_floating_point(buffer): |
|
|
tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" |
|
|
else: |
|
|
tensor_str = ( |
|
|
f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" |
|
|
) |
|
|
if buffer.is_cuda: |
|
|
tensor_str = f"{tensor_str}.cuda()" |
|
|
model_str += ( |
|
|
f"{tab * 2}self.register_buffer('{buffer_name}', {tensor_str})\n" |
|
|
) |
|
|
|
|
|
for param_name, param in gm._parameters.items(): |
|
|
if param is None: |
|
|
continue |
|
|
maybe_device = "" |
|
|
if param.is_cuda: |
|
|
maybe_device = ', device="cuda"' |
|
|
tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))" |
|
|
model_str += f"{tab * 2}self.{param_name} = {tensor_str}\n" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_str += f"{_addindent(gm.code, 4)}\n" |
|
|
return model_str |
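

# A hedged example of converting a traced module to source text; the toy
# module below is illustrative, not part of this file's API:
#
#   mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#   gm = torch.fx.symbolic_trace(mod)
#   if NNModuleToString.can_convert_to_string(gm):
#       print(NNModuleToString.convert(gm))  # emits a runnable `class Repro` definition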
|
|
|
|
|
|
|
|
@functools.cache |
|
|
def _cuda_system_info_comment() -> str: |
|
|
if not torch.cuda.is_available(): |
|
|
return "# torch.cuda.is_available()==False, no GPU info collected\n" |
|
|
|
|
|
model_str = "# CUDA Info: \n" |
|
|
try: |
|
|
cuda_version_out = subprocess.check_output(["nvcc", "--version"]) |
|
|
cuda_version_lines = cuda_version_out.decode().split("\n") |
|
|
        comment = "".join(f"# {s} \n" for s in cuda_version_lines if s)
|
|
model_str += f"{comment}\n" |
|
|
except (FileNotFoundError, subprocess.CalledProcessError): |
|
|
model_str += "# nvcc not found\n" |
|
|
|
|
|
gpu_names = Counter( |
|
|
torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) |
|
|
) |
|
|
|
|
|
model_str += "# GPU Hardware Info: \n" |
|
|
for name, count in gpu_names.items(): |
|
|
model_str += f"# {name} : {count} \n" |
|
|
model_str += "\n" |
|
|
return model_str |
|
|
|
|
|
|
|
|
def generate_env_vars_string(*, stable_output: bool = False) -> str: |
|
|
""" |
|
|
Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton. |
|
|
""" |
|
|
if stable_output: |
|
|
return "# env var omitted due to stable_output=True" |
|
|
|
|
|
allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"] |
|
|
skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"] |
|
|
|
|
|
def filter(key: str) -> bool: |
|
|
return any(string in key for string in allow_list) and key not in skip_list |
|
|
|
|
|
config_lines = [ |
|
|
f"os.environ['{key}'] = '{value}'" |
|
|
for key, value in os.environ.items() |
|
|
if filter(key) |
|
|
] |
|
|
config_string = "\n".join(config_lines) |
|
|
return normalize_path_separator(f"""\ |
|
|
import os |
|
|
{config_string} |
|
|
""") |
|
|
|
|
|
|
|
|
def generate_config_string(*, stable_output: bool = False) -> str: |
|
|
import torch._functorch.config |
|
|
import torch._inductor.config |
|
|
|
|
|
if stable_output: |
|
|
return "# config omitted due to stable_output=True" |
|
|
|
|
|
experimental_config = torch.fx.experimental._config.codegen_config() |
|
|
return f"""\ |
|
|
import torch._dynamo.config |
|
|
import torch._inductor.config |
|
|
import torch._functorch.config |
|
|
import torch.fx.experimental._config |
|
|
{torch._dynamo.config.codegen_config()} |
|
|
{torch._inductor.config.codegen_config()} |
|
|
{torch._functorch.config.codegen_config()} |
|
|
{experimental_config} |
|
|
""" |
|
|
|
|
|
|
|
|
def get_minifier_repro_path() -> str: |
|
|
return os.path.join(minifier_dir(), "minifier_launcher.py") |
|
|
|
|
|
|
|
|
def helper_for_dump_minify(contents: str) -> None: |
|
|
minified_repro_path = get_minifier_repro_path() |
|
|
log.warning("Writing minified repro to:\n%s", minified_repro_path) |
|
|
|
|
|
if use_buck: |
|
|
BuckTargetWriter(minified_repro_path).write() |
|
|
try: |
|
|
with open(minified_repro_path, "w") as fd: |
|
|
fd.write(contents) |
|
|
|
|
|
except OSError as e: |
|
|
log.exception("") |
|
|
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
|
|
|
|
|
|
|
|
class AccuracyError(Exception): |
|
|
pass |
|
|
|
|
|
|
|
|
def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]: |
|
|
""" |
|
|
    This clone of the inputs differs from utils.clone_input. In the minifier,
    all tensors are leaf tensors while creating a new graph, so we set the
    requires_grad field without checking whether the tensor is a leaf.
|
|
""" |
|
|
cloned_inputs = clone_inputs(example_inputs) |
|
|
for idx in range(len(example_inputs)): |
|
|
if isinstance(cloned_inputs[idx], torch.Tensor): |
|
|
cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) |
|
|
return cloned_inputs |
|
|
|
|
|
|
|
|
def run_fwd_maybe_bwd( |
|
|
gm: torch.fx.GraphModule, |
|
|
args: Sequence[Any], |
|
|
only_fwd: bool = False, |
|
|
disable_clone: bool = False, |
|
|
) -> Any: |
|
|
""" |
|
|
    Runs a forward and possibly a backward pass for the given module and args.
|
|
|
|
|
When disable_clone is True, we will use args as-is without cloning. |
|
|
This is higher fidelity but we may destroy the args in the process. |
|
|
""" |
|
|
from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass |
|
|
|
|
|
gm = copy.deepcopy(gm) |
|
|
if not disable_clone: |
|
|
args = clone_inputs_retaining_gradness(args) |
|
|
|
|
|
if hasattr(gm, "zero_grad"): |
|
|
gm.zero_grad(True) |
|
|
|
|
|
|
|
|
out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args) |
|
|
|
|
|
if only_fwd: |
|
|
return out |
|
|
if requires_bwd_pass(out): |
|
|
loss = reduce_to_scalar_loss(out) |
|
|
loss.backward() |
|
|
return collect_results(gm, out, None, args) |
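

# Usage sketch (assuming `gm` is a torch.fx.GraphModule and `args` its example
# inputs): run forward only, or forward plus backward when any output requires
# grad, collecting gradients alongside the outputs.
#
#   fwd_only = run_fwd_maybe_bwd(gm, args, only_fwd=True)
#   results = run_fwd_maybe_bwd(gm, args)  # fwd + bwd when applicable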
|
|
|
|
|
|
|
|
def same_two_models( |
|
|
gm: torch.fx.GraphModule, |
|
|
opt_gm: torch.fx.GraphModule, |
|
|
example_inputs: Sequence[Any], |
|
|
only_fwd: bool = False, |
|
|
*, |
|
|
require_fp64: bool = False, |
|
|
ignore_non_fp: bool = False, |
|
|
) -> bool: |
|
|
""" |
|
|
Check two models have same accuracy. |
|
|
|
|
|
    require_fp64: if True, raise an error if we are unable to compute the fp64 reference
|
|
ignore_non_fp: if True, do not compare outputs which are not floating point. This |
|
|
is mostly useful for the minifier (which wants to avoid quantizing floating point |
|
|
error into integer/boolean error) |
|
|
""" |
|
|
from .utils import same |
|
|
|
|
|
ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) |
|
|
|
|
|
fp64_ref = None |
|
|
if config.same_two_models_use_fp64: |
|
|
try: |
|
|
fp64_model, fp64_examples = cast_to_fp64( |
|
|
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) |
|
|
) |
|
|
fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) |
|
|
except Exception: |
|
|
if require_fp64: |
|
|
raise RuntimeError( |
|
|
"Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False" |
|
|
) |
|
|
log.warning("Could not generate fp64 outputs") |
|
|
|
|
|
try: |
|
|
res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) |
|
|
except Exception: |
|
|
|
|
|
|
|
|
log.exception( |
|
|
"While minifying the program in accuracy minification mode, " |
|
|
"ran into a runtime exception which is likely an unrelated issue." |
|
|
" Skipping this graph." |
|
|
) |
|
|
return True |
|
|
|
|
|
passing = same( |
|
|
ref, |
|
|
res, |
|
|
fp64_ref, |
|
|
tol=config.repro_tolerance, |
|
|
equal_nan=True, |
|
|
ignore_non_fp=ignore_non_fp, |
|
|
) |
|
|
return passing |
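

# Accuracy-checking sketch: compare an eager GraphModule against its compiled
# counterpart on the same inputs (`gm`, `opt_gm`, and `example_inputs` are
# assumed to come from the minifier harness):
#
#   if not same_two_models(gm, opt_gm, example_inputs, ignore_non_fp=True):
#       raise AccuracyError("compiled model diverged from the eager reference")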
|
|
|
|
|
|
|
|
def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
|
|
for node in model.graph.nodes: |
|
|
if ( |
|
|
node.op == "call_function" |
|
|
and node.target == torch.ops.prims.convert_element_type.default |
|
|
): |
|
|
assert len(node.args) == 2 |
|
|
if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: |
|
|
node.args = (node.args[0], torch.float64) |
|
|
if node.op == "call_function": |
|
|
dtype = node.kwargs.get("dtype") |
|
|
if dtype is not None and is_float_dtype(dtype): |
|
|
new_kwargs = dict(node.kwargs) |
|
|
new_kwargs["dtype"] = torch.float64 |
|
|
node.kwargs = new_kwargs |
|
|
|
|
|
model.graph.lint() |
|
|
model.recompile() |
|
|
return model |
|
|
|
|
|
|
|
|
def cast_to( |
|
|
dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any] |
|
|
) -> tuple[torch.fx.GraphModule, list[Any]]: |
|
|
from torch.utils._pytree import tree_map |
|
|
|
|
|
model = model.to(dtype) |
|
|
if dtype == torch.float64: |
|
|
|
|
|
|
|
|
model = cast_dtype_args_to_fp64(model) |
|
|
|
|
|
inputs = tree_map( |
|
|
lambda x: x.to(dtype) |
|
|
if isinstance(x, torch.Tensor) and x.is_floating_point() |
|
|
else x, |
|
|
inputs, |
|
|
) |
|
|
return model, inputs |
|
|
|
|
|
|
|
|
def cast_to_fp64( |
|
|
model: torch.fx.GraphModule, inputs: list[Any] |
|
|
) -> tuple[torch.fx.GraphModule, list[Any]]: |
|
|
return cast_to(torch.float64, model, inputs) |
|
|
|
|
|
|
|
|
def backend_accuracy_fails( |
|
|
gm: torch.fx.GraphModule, |
|
|
example_inputs: Sequence[Any], |
|
|
compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], |
|
|
only_fwd: bool = False, |
|
|
*, |
|
|
require_fp64: bool = False, |
|
|
ignore_non_fp: bool = False, |
|
|
) -> bool: |
|
|
try: |
|
|
compiled_gm = compiler_fn( |
|
|
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) |
|
|
) |
|
|
return not same_two_models( |
|
|
gm, |
|
|
compiled_gm, |
|
|
example_inputs, |
|
|
only_fwd, |
|
|
require_fp64=require_fp64, |
|
|
ignore_non_fp=ignore_non_fp, |
|
|
) |
|
|
except Exception: |
|
|
|
|
|
|
|
|
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _stride_or_default( |
|
|
stride: Optional[torch._prims_common.StrideType], |
|
|
*, |
|
|
shape: torch._prims_common.ShapeType, |
|
|
) -> torch._prims_common.StrideType: |
|
|
return stride if stride is not None else utils.make_contiguous_strides_for(shape) |
|
|
|
|
|
|
|
|
def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]: |
|
|
return lambda x: x if x is not None else d |
|
|
|
|
|
|
|
|
_dtype_or_default = _mk_defaulter(torch.float32) |
|
|
_device_or_default = _mk_defaulter(torch.device("cpu")) |
|
|
_storage_offset_or_default = _mk_defaulter(0) |
|
|
_requires_grad_or_default = _mk_defaulter(False) |
|
|
_is_leaf_or_default = _mk_defaulter(False) |
|
|
|
|
|
|
|
|
class NopInputReader: |
|
|
def __init__(self) -> None: |
|
|
self.total = 0 |
|
|
|
|
|
def storage( |
|
|
self, |
|
|
storage_hash: Optional[str], |
|
|
nbytes: int, |
|
|
*, |
|
|
device: Optional[torch._prims_common.DeviceLikeType] = None, |
|
|
dtype_hint: Optional[torch.dtype] = None, |
|
|
) -> None: |
|
|
self.total += 1 |
|
|
|
|
|
def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]: |
|
|
pass |
|
|
|
|
|
def symint(self, *args: Any, **kwargs: Any) -> Optional[int]: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputReader: |
|
|
def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = None): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if save_dir is None: |
|
|
log.warning("no save_dir specified, will generate random data") |
|
|
self.store = ContentStoreReader(save_dir) if save_dir is not None else None |
|
|
self.args: list[Any] = [] |
|
|
self.pbar = pbar |
|
|
|
|
|
def storage( |
|
|
self, |
|
|
storage_hash: Optional[str], |
|
|
nbytes: int, |
|
|
*, |
|
|
device: Optional[torch._prims_common.DeviceLikeType] = None, |
|
|
dtype_hint: Optional[torch.dtype] = None, |
|
|
) -> UntypedStorage: |
|
|
if self.pbar is not None: |
|
|
self.pbar.update(1) |
|
|
device = _device_or_default(device) |
|
|
dtype_hint = _dtype_or_default(dtype_hint) |
|
|
if self.store is not None and storage_hash is not None: |
|
|
try: |
|
|
storage = self.store.read_storage(storage_hash) |
|
|
except FileNotFoundError: |
|
|
pass |
|
|
else: |
|
|
if device != storage.device: |
|
|
log.warning("device mismatch: %s != %s", device, storage.device) |
|
|
|
|
|
|
|
|
|
|
|
return storage |
|
|
log.warning("could not load %s, generating random data instead", storage_hash) |
|
|
shape = (nbytes // dtype_hint.itemsize,) |
|
|
stride = _stride_or_default(None, shape=shape) |
|
|
return rand_strided(shape, stride, dtype_hint, device).untyped_storage() |
|
|
|
|
|
def tensor( |
|
|
self, |
|
|
storage: UntypedStorage, |
|
|
shape: torch._prims_common.ShapeType, |
|
|
stride: Optional[torch._prims_common.StrideType] = None, |
|
|
*, |
|
|
storage_offset: Optional[int] = None, |
|
|
dtype: Optional[torch.dtype] = None, |
|
|
requires_grad: Optional[bool] = None, |
|
|
is_leaf: Optional[bool] = None, |
|
|
**metadata: Any, |
|
|
) -> torch.Tensor: |
|
|
stride = _stride_or_default(stride, shape=shape) |
|
|
storage_offset = _storage_offset_or_default(storage_offset) |
|
|
dtype = _dtype_or_default(dtype) |
|
|
is_leaf = _is_leaf_or_default(is_leaf) |
|
|
requires_grad = _requires_grad_or_default(requires_grad) |
|
|
t = torch.tensor( |
|
|
[], dtype=dtype, device=storage.device, requires_grad=requires_grad |
|
|
) |
|
|
with torch.no_grad(): |
|
|
t.set_(storage, storage_offset, shape, stride) |
|
|
if not is_leaf: |
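            # Simulate a non-leaf tensor: cloning under grad mode records
            # autograd history (when requires_grad is set), and the second
            # set_() points the clone back at the original storage.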
|
|
|
|
|
with torch.enable_grad(): |
|
|
t = t.clone(memory_format=torch.preserve_format) |
|
|
with torch.no_grad(): |
|
|
t.set_(storage, storage_offset, shape, stride) |
|
|
assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf |
|
|
torch._utils.set_tensor_metadata(t, metadata) |
|
|
self.args.append(t) |
|
|
return t |
|
|
|
|
|
def symint(self, val: Any) -> Any: |
|
|
self.args.append(val) |
|
|
return val |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputWriter: |
|
|
def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None: |
|
|
self._lines: list[str] = [] |
|
|
|
|
|
self.storage_counter = itertools.count() |
|
|
self.save_dir = save_dir |
|
|
self.store = ( |
|
|
ContentStoreWriter(save_dir, stable_hash=stable_hash) |
|
|
if save_dir is not None |
|
|
else None |
|
|
) |
|
|
self.seen_storages: dict[StorageWeakRef, str] = {} |
|
|
|
|
|
def lines(self) -> list[str]: |
|
|
r = [ |
|
|
"def load_args(reader):", |
|
|
] |
|
|
        r.extend(f"    {l}" for l in self._lines)
|
|
|
|
|
|
|
|
r.append("load_args._version = 0") |
|
|
return r |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def storage( |
|
|
self, |
|
|
untyped_storage: UntypedStorage, |
|
|
*, |
|
|
device_hint: Optional[torch._prims_common.DeviceLikeType] = None, |
|
|
dtype_hint: Optional[torch.dtype] = None, |
|
|
) -> str: |
|
|
ws = StorageWeakRef(untyped_storage) |
|
|
v = self.seen_storages.get(ws) |
|
|
if v is not None: |
|
|
return v |
|
|
v = f"buf{next(self.storage_counter)}" |
|
|
maybe_dtype_hint = "" |
|
|
if _dtype_or_default(None) != _dtype_or_default(dtype_hint): |
|
|
maybe_dtype_hint = f", dtype_hint={dtype_hint!r}" |
|
|
|
|
|
|
|
|
maybe_device = "" |
|
|
device = untyped_storage.device |
|
|
if device.type == "meta": |
|
|
assert device_hint is not None |
|
|
device = device_hint |
|
|
if _device_or_default(None) != device: |
|
|
maybe_device = f", device={device!r}" |
|
|
nbytes = untyped_storage.nbytes() |
|
|
storage_hash = None |
|
|
if self.store is not None and untyped_storage.device.type != "meta": |
|
|
storage_hash = self.store.write_storage(untyped_storage) |
|
|
self._lines.append( |
|
|
f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})" |
|
|
) |
|
|
self.seen_storages[ws] = v |
|
|
return v |
|
|
|
|
|
def tensor(self, name: str, t: torch.Tensor) -> None: |
|
|
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq |
|
|
|
|
|
storage = self.storage( |
|
|
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device |
|
|
) |
|
|
args = [] |
|
|
|
|
|
if not statically_known_true( |
|
|
sym_eq(_stride_or_default(None, shape=t.shape), t.stride()) |
|
|
): |
|
|
args.append(str(tuple(t.stride()))) |
|
|
if _dtype_or_default(None) != t.dtype: |
|
|
args.append(f"dtype={t.dtype!r}") |
|
|
if not statically_known_true( |
|
|
_storage_offset_or_default(None) == t.storage_offset() |
|
|
): |
|
|
args.append(f"storage_offset={t.storage_offset()!r}") |
|
|
tensor_metadata = torch._utils.get_tensor_metadata(t) |
|
|
if tensor_metadata: |
|
|
args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items()) |
|
|
if _requires_grad_or_default(None) != t.requires_grad: |
|
|
args.append(f"requires_grad={t.requires_grad!r}") |
|
|
is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t) |
|
|
if _is_leaf_or_default(None) != is_leaf: |
|
|
args.append(f"is_leaf={is_leaf!r}") |
|
|
self._lines.append( |
|
|
"reader.tensor(" |
|
|
+ ", ".join([storage, str(tuple(t.shape)), *args]) |
|
|
+ f") # {name}" |
|
|
) |
|
|
|
|
|
def unsupported(self, name: str, arg: Any) -> None: |
|
|
|
|
|
        self._lines.append(f"# {name} was an unsupported type for dumping: {type(arg)}")
|
|
|
|
|
|
|
|
if isinstance(arg, (list, tuple)): |
|
|
self._lines.append('"""') |
|
|
for i, a in enumerate(arg): |
|
|
name_i = f"{name}[{i}]" |
|
|
if isinstance(a, torch.Tensor): |
|
|
self.tensor(name_i, a) |
|
|
elif isinstance(a, (int, torch.SymInt)): |
|
|
self.symint(name_i, a) |
|
|
else: |
|
|
self.unsupported(name_i, a) |
|
|
self._lines.append('"""') |
|
|
|
|
|
|
|
|
def const(self, name: str) -> None: |
|
|
self._lines.append( |
|
|
f"reader.const({name!r}) # {name}, filtered out during compilation" |
|
|
) |
|
|
|
|
|
|
|
|
def symint(self, name: str, val: Any) -> None: |
|
|
if isinstance(val, torch.SymInt): |
|
|
val = val.node.hint |
|
|
self._lines.append(f"reader.symint({val!r}) # {name}") |
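

# Round-trip sketch: InputWriter serializes example inputs into a `load_args`
# snippet (tensor data is only persisted when a save_dir is given), and
# InputReader replays it. The tensor `x` below is illustrative.
#
#   x = torch.randn(4, 4)
#   writer = InputWriter(save_dir=None)
#   writer.tensor("args[0]", x)
#   writer.symint("args[1]", 7)
#   print("\n".join(writer.lines()))
#   # prints roughly:
#   # def load_args(reader):
#   #     buf0 = reader.storage(None, 64)
#   #     reader.tensor(buf0, (4, 4), is_leaf=True)  # args[0]
#   #     reader.symint(7)  # args[1]
#   # load_args._version = 0
#
#   # Replaying with no save_dir regenerates random data of the right shape:
#   reader = InputReader(save_dir=None)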
|
|
|
|
|
|
|
|
def aot_graph_input_parser( |
|
|
func: Callable[[list[Tensor]], list[Tensor]], |
|
|
device: str = "cuda", |
|
|
sym_shapes: Optional[dict[str, int]] = None, |
|
|
default_sym_shape: Optional[int] = None, |
|
|
) -> dict[str, Any]: |
|
|
""" |
|
|
Takes in a function which has been printed with print_readable() and constructs kwargs to run it. |
|
|
|
|
|
Handles Tensor inputs, Symints, and a graph module which might have tensor constants. |
|
|
|
|
|
Consider a function `forward` defined as follows: |
|
|
|
|
|
def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",): |
|
|
_tensor_constant0: "i64[4190]" = self._tensor_constant0 |
|
|
# Further implementation |
|
|
|
|
|
kwargs = aot_graph_input_parser(forward) |
|
|
forward(**kwargs) |
|
|
""" |
|
|
|
|
|
from torch.utils._dtype_abbrs import dtype_abbrs |
|
|
|
|
|
dtype_map: dict[str, torch.dtype] = { |
|
|
value: key for key, value in dtype_abbrs.items() |
|
|
} |
|
|
dtype_pattern: str = "|".join(dtype_abbrs.values()) |
|
|
|
|
|
|
|
|
source = inspect.getsource(func) |
|
|
|
|
|
|
|
|
tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)" |
|
|
tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]" |
|
|
sym_shape_regex = r"Sym\((s\d+)\)" |
|
|
|
|
|
class TensorContainer: |
|
|
"Container for tensors as attributes" |
|
|
|
|
|
|
|
|
kwargs: dict[str, Any] = {} |
|
|
|
|
|
sym_shapes_dict: dict[str, int] = sym_shapes or {} |
|
|
|
|
|
def get_sym_int(symint: str) -> int: |
|
|
torch._check( |
|
|
symint in sym_shapes_dict or default_sym_shape is not None, |
|
|
lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", |
|
|
) |
|
|
return sym_shapes_dict.get(symint, default_sym_shape) |
|
|
|
|
|
def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor: |
|
|
|
|
|
resolved_shape = [] |
|
|
dynamic_dims = [] |
|
|
for i, dim in enumerate(shape): |
|
|
dim = dim.strip() |
|
|
if "s" in dim: |
|
|
s = get_sym_int(dim) |
|
|
resolved_shape.append(s) |
|
|
dynamic_dims.append(i) |
|
|
else: |
|
|
if dim: |
|
|
resolved_shape.append(int(dim)) |
|
|
|
|
|
constructor = torch.randn if dtype.is_floating_point else torch.zeros |
|
|
out = constructor(resolved_shape, dtype=dtype, device=device) |
|
|
for d in dynamic_dims: |
|
|
torch._dynamo.mark_dynamic(out, d) |
|
|
return out |
|
|
|
|
|
|
|
|
annotations = func.__annotations__ |
|
|
for param, annotation in annotations.items(): |
|
|
|
|
|
if param == "return": |
|
|
continue |
|
|
|
|
|
match = re.search(tensor_regex, annotation) |
|
|
if match: |
|
|
data_type, shape_str = match.groups() |
|
|
shape = tuple(shape_str.split(",")) |
|
|
dtype = dtype_map[data_type] |
|
|
kwargs[param] = gen_tensor(shape, dtype) |
|
|
|
|
|
match = re.search(sym_shape_regex, annotation) |
|
|
if match: |
|
|
kwargs[param] = get_sym_int(match.group(1)) |
|
|
|
|
|
if "self" in inspect.signature(func).parameters: |
|
|
container = TensorContainer() |
|
|
kwargs["self"] = container |
|
|
for match in re.finditer(tensor_assignment_regex, source): |
|
|
attr_name, data_type, shape_str, _ = match.groups() |
|
|
shape = tuple(shape_str.split(",")) |
|
|
dtype = dtype_map[data_type] |
|
|
setattr(container, attr_name, gen_tensor(shape, dtype)) |
|
|
|
|
|
return kwargs |
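

# A hedged sketch of supplying concrete sizes for symbolic dimensions
# (`forward` is the printed graph function from the docstring above; the
# size for "s0" is illustrative):
#
#   kwargs = aot_graph_input_parser(forward, sym_shapes={"s0": 128})
#   forward(**kwargs)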
|
|
|
|
|
|
|
|
def profile_to_file(filename: str) -> Callable[[T], T]: |
|
|
""" |
|
|
Decorator to cProfile a given function and save the result to disk on process exit. |
|
|
|
|
|
Args: |
|
|
filename: filename to save profile to |
|
|
""" |
|
|
prof = cProfile.Profile() |
|
|
filename = os.path.abspath(os.path.expanduser(filename)) |
|
|
|
|
|
def decorator(fn: Any) -> Any: |
|
|
@functools.wraps(fn) |
|
|
def wrapper(*args: Any, **kwargs: Any) -> Any: |
|
|
prof.enable() |
|
|
try: |
|
|
return fn(*args, **kwargs) |
|
|
finally: |
|
|
prof.disable() |
|
|
|
|
|
return wrapper |
|
|
|
|
|
def save_it() -> None: |
|
|
prof.dump_stats(filename) |
|
|
sys.stderr.write( |
|
|
textwrap.dedent( |
|
|
f"""\ |
|
|
Wrote profile to {filename}, view with: |
|
|
|
|
|
snakeviz {filename} |
|
|
|
|
|
""" |
|
|
) |
|
|
) |
|
|
|
|
|
atexit.register(save_it) |
|
|
return decorator |
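

# Usage sketch: profile every call to a hot function and dump cumulative
# stats at process exit (the decorated function below is illustrative):
#
#   @profile_to_file("~/dynamo_compile.prof")
#   def compile_step():
#       ...
#
#   # on exit, the profile is written and a `snakeviz` hint is printed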
|
|
|