koichi12's picture
Add files using upload-large-folder tool
36cbb94 verified
raw
history blame
12.6 kB
from .patched_boxed_run import patched_boxed_run
from .patched_lazy_format_graph_code import patched_lazy_format_graph_code
from .patched_load_by_key_path import patched_load_by_key_path
from .patched__exec_with_source import patched__exec_with_source
from typing import List, Tuple, Dict, Union, Callable, Optional, Any
import contextlib
import warnings
import traceback
import dataclasses
import itertools
import sys
import os
import inspect
@dataclasses.dataclass
class DebuggableHook(object):
dump_src_dir: str
log_bytecode: bool
optimized_code_and_module: List =dataclasses.field(default_factory=list, init=False)
def __call__(self, code, new_code):
frame = sys._getframe()
import os
while True:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == code
self.optimized_code_and_module.append([code, frame.f_globals])
from depyf.decompiler import DecompilationError
try:
import os
# replace " "/"<"/">"/"." with "_"
func_name = code.co_name.replace(".", "_").replace("<", "_").replace(">", "_").replace(" ", "_")
filepath_template = os.path.join(
self.dump_src_dir,
f"__transformed_code_%s_for_{func_name}.py")
from depyf.explain.utils import lock_on_file
from depyf.decompiler import Decompiler
# function name and file name are related.
with lock_on_file(filepath_template):
decompiled_and_compiled_back_code = Decompiler.decompile_and_compile_like(code_to_decompile=new_code, reference_code=code, filepath_template=filepath_template)
filename = decompiled_and_compiled_back_code.co_filename
if self.log_bytecode:
with lock_on_file(filename):
import dill
# code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`.
# see https://github.com/pytorch/pytorch/issues/116013 for more details.
with contextlib.suppress(Exception):
dill.dump(code, open(filename + ".original_bytecode", "wb"))
with contextlib.suppress(Exception):
dill.dump(new_code, open(filename + ".transformed_bytecode", "wb"))
with contextlib.suppress(Exception):
dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb"))
# this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487
from torch._dynamo.utils import orig_code_map
from torch._dynamo.convert_frame import output_codes
output_codes.add(decompiled_and_compiled_back_code)
orig_code_map[decompiled_and_compiled_back_code] = code
return decompiled_and_compiled_back_code
except (DecompilationError, SyntaxError) as e:
from io import StringIO
string_io = StringIO()
import dis
print("There is a problem when decompiling and compiling the following code:", file=string_io)
dis.dis(new_code, file=string_io)
print("Please consider submitting an issue to https://github.com/thuml/depyf .", file=string_io)
# do not stop the program for decompilation error and compile error
warnings.warn(string_io.getvalue())
traceback.print_exc()
@contextlib.contextmanager
def patch(parent, name, value):
old_value = getattr(parent, name, None)
if old_value is not None:
setattr(parent, name, value)
try:
yield
finally:
if old_value is not None:
setattr(parent, name, old_value)
@contextlib.contextmanager
def enable_bytecode_hook(hook):
import torch
handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
try:
yield
finally:
handle.remove()
@contextlib.contextmanager
def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):
"""
A context manager to dump debugging information for torch.compile.
It should wrap the code that actually triggers the compilation, rather than
the code that applies ``torch.compile``.
Example:
.. code-block:: python
import torch
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
# main()
# surround the code you want to run inside `with depyf.prepare_debug`
import depyf
with depyf.prepare_debug("./dump_src_dir"):
main()
After running the code, you will find the dumped information in the directory ``dump_src_dir``. The details are organized into the following:
- ``full_code_for_xxx.py`` for each function using torch.compile
- ``__transformed_code_for_xxx.py`` for Python code associated with each graph.
- ``__transformed_code_for_xxx.py.xxx_bytecode`` for Python bytecode, dumped code object, can be loaded via ``dill.load(open("/path/to/file", "wb"))``. Note that the load function might import some modules like transformers. Make sure you have these modules installed.
- ``__compiled_fn_xxx.py`` for each computation graph and its optimization:
- ``Captured Graph``: a plain forward computation graph
- ``Joint Graph``: joint forward-backward graph from AOTAutograd
- ``Forward Graph``: forward graph from AOTAutograd
- ``Backward Graph``: backward graph from AOTAutograd
- ``kernel xxx``: compiled CPU/GPU kernel wrapper from Inductor.
Arguments:
- ``dump_src_dir``: the directory to dump the source code.
- ``clean_wild_fx_code``: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
- ``log_bytecode``: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code).
"""
if not isinstance(dump_src_dir, str):
raise RuntimeError('''You are using an obsolete usage style`depyf.prepare_debug(func=function, dump_src_dir="/path")`. Please use `depyf.prepare_debug(dump_src_dir="/path")` instead, which will automatically capture all compiled functions.''')
import os
import torch
current_line_number = inspect.currentframe().f_lineno + 1
warnings.warn_explicit(f"{__file__}:{current_line_number}: You are trying to debug `torch.compile`. Please make sure the code runs multiple times to cover all the possible branches.", UserWarning, "", 0)
from depyf.utils import safe_create_directory
if not os.path.exists(dump_src_dir):
safe_create_directory(dump_src_dir)
dump_src_dir = os.path.abspath(dump_src_dir)
from .global_variables import data
data["dump_src_dir"] = dump_src_dir
data["unpatched__exec_with_source"] = torch.fx.graph_module._exec_with_source
data["unpatched_load_by_key_path"] = torch._inductor.codecache.PyCodeCache.load_by_key_path
data["unpatched___call__"] = torch._dynamo.eval_frame.OptimizeContext.__call__
data["is_inside_prepare_debug"] = True
bytecode_hook = DebuggableHook(dump_src_dir, log_bytecode)
# patch some functions
with patch(torch.fx.graph_module, "_exec_with_source", patched__exec_with_source), \
patch(torch._inductor.codecache.PyCodeCache, "load_by_key_path", patched_load_by_key_path), \
patch(torch._dynamo.utils.lazy_format_graph_code, "__code__", patched_lazy_format_graph_code.__code__):
# we have to directly manipulate the code object, since the function has been imported in many places.
# simply replacing torch._dynamo.utils.lazy_format_graph_code does not work for those functions.
# Note: `unitest.mock.patch` does not work here, since it will not
# patch the code object. (it will try to delete the code object and
# then set a new code object. The `delattr` will raise an error.)
# enable bytecode hook
with enable_bytecode_hook(bytecode_hook):
try:
yield
finally:
code_names = {x[0].co_name for x in bytecode_hook.optimized_code_and_module}
for code, module in bytecode_hook.optimized_code_and_module:
if code.co_name.startswith("resume_in_") and any(f"resume_in_{name}" in code.co_name for name in code_names):
continue
# https://github.com/pytorch/pytorch/pull/118201 introduces `torch_dynamo_resume_in_` names.
if code.co_name.startswith("torch_dynamo_resume_in_") and any(f"torch_dynamo_resume_in_{name}" in code.co_name for name in code_names):
continue
from depyf.explain import dump_src
from depyf.explain.utils import write_code_to_file_template
from torch._dynamo.eval_frame import innermost_fn, _debug_get_cache_entry_list
entries = _debug_get_cache_entry_list(code)
if not entries:
current_line_number = inspect.currentframe().f_lineno + 1
warnings.warn_explicit(f"{__file__}:{current_line_number}: Code object {code} is compiled but does not have any compiled cache entries. Probably some torch.nn.Module instances are destroyed too early. It is recommended to make sure the torch.nn.Module instances exist after `with depyf.prepare_debug`.", UserWarning, "", 0)
full_src = dump_src(code, module)
filepath_template = os.path.join(dump_src_dir, f"full_code_for_{code.co_name}_%s.py")
full_code_path = write_code_to_file_template(full_src, filepath_template)
for file in os.listdir(dump_src_dir):
name = file.split(os.path.sep)[-1]
# remove *.lock file and possibly fx_graph_code* file
if (clean_wild_fx_code and name.startswith("fx_graph_code")) or name.endswith(".lock"):
try:
# multiple processes may try to remove the same file.
os.remove(os.path.join(dump_src_dir, file))
except OSError:
pass
data["is_inside_prepare_debug"] = False
@contextlib.contextmanager
def debug():
"""
A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program and allows you to check the full source code in files with prefix ``full_code_for_`` in the ``dump_src_dir`` argument of :func:`depyf.prepare_debug`, and set breakpoints in their separate ``__transformed_code_`` files according to the function name. Then continue your debugging.
"""
from .global_variables import data
if data["is_inside_prepare_debug"]:
raise RuntimeError("You cannot use `depyf.debug` inside `depyf.prepare_debug`.")
dump_src_dir = data["dump_src_dir"]
import torch
# after https://github.com/pytorch/pytorch/pull/131258
# torch._dynamo.eval_frame.set_eval_frame is not available in the module
# we need to directly access it from the `_C` extension.
callback = torch._C._dynamo.eval_frame.set_eval_frame(False)
# sometimes pytorch use Interpreter to run node by node. This cannot be debugged.
# we patch this function to run the graph function directly.
with patch(torch.fx.Interpreter.boxed_run, "__code__", patched_boxed_run.__code__):
try:
msg = f"`depyf` places a breakpoint here to pause the program. You can check the full source code in files with prefix `full_code_for_` in {dump_src_dir} first, and set breakpoints in their separate files according to the function name. Then continue your debugging."
print(msg)
breakpoint()
yield
finally:
torch._C._dynamo.eval_frame.set_eval_frame(callback)