koichi12's picture
Add files using upload-large-folder tool
36cbb94 verified
raw
history blame
13.7 kB
import torch
from torch._dynamo.eval_frame import innermost_fn
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
import inspect
import dis
from types import CodeType
from typing import List, Callable, Dict, Union, Set
from dataclasses import dataclass
import contextlib
class CodeProxy:
instances: Dict[str, "CodeProxy"] = {}
used_instances: Set[str] = set()
@staticmethod
def get_new_name(name: str):
i = 0
new_name = name
if new_name.endswith(":"):
name = name[:-1]
while True:
new_name = f"{name}_{i}"
if new_name not in CodeProxy.instances:
break
i += 1
return new_name
@staticmethod
def consume_new_name(name: str):
new_name = CodeProxy.get_new_name(name)
CodeProxy.instances[new_name] = None
return new_name
@staticmethod
def decompile_with_name(code: CodeType, name: str, skip_decompile=False):
from depyf.utils import decompile_ensure
if hasattr(code, "__code__"):
code = code.__code__
if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"):
src = open(code.co_filename).read()
new_name = code.co_name
else:
new_name = CodeProxy.get_new_name(name)
if not skip_decompile:
src = decompile_ensure(code, new_name)
else:
src = ""
self = CodeProxy(src)
self.name = new_name
self.code = f"""<details>
<summary>{self.name}</summary>
```python
{self.raw_code}
```
</details>
"""
CodeProxy.instances[self.name] = self
return self
def __init__(self, code: str):
# Don't directly use this constructor. Use decompile_with_name instead.
self.raw_code = "".join(
[" " + line + "\n" for line in code.splitlines() if line.strip() != ""])
def __str__(self):
CodeProxy.used_instances.add(self.name)
return self.name
@contextlib.contextmanager
@staticmethod
def record():
CodeProxy.used_instances = set()
yield CodeProxy.used_instances
@dataclass
class CacheResult:
original_code: CodeType
transformed_code: CodeType
guard: List[str]
compiled_subgraph: Callable
compiled_subgraph_proxy: CodeProxy
transformed_code_proxy: CodeProxy
referenced_global_functions: Dict[str, "DynamoOptimizationResult"]
def __init__(self, original_code, module, cache):
self.original_code = original_code
cpp_guard = False
# starting from https://github.com/pytorch/pytorch/pull/138896 ,
# pytorch uses `guard_manager` instead of `check_fn` to store the
# guards
attr_name = "guard_manager" if hasattr(cache, "guard_manager") else "check_fn"
guard_manager = getattr(cache, attr_name)
try:
klass = getattr(torch._dynamo.guards, "GuardManagerWrapper", None) or \
getattr(torch._dynamo.guards, "GuardManager", None) or \
getattr(torch._C._dynamo.guards, "GuardManager", None)
assert klass is not None
cpp_guard = isinstance(guard_manager, klass)
except Exception:
pass
if not cpp_guard:
# for old version of pytorch,
# `guard_manager` is a plain python function
guard_codes = guard_manager.code_parts
freevar_names = guard_manager.__code__.co_freevars
freevar_values = [x.cell_contents for x in guard_manager.__closure__]
else:
# keep the logic synced with
# https://github.com/pytorch/pytorch/blob/7b6b10417d8616ebd7a42b06528c5c2b2fded55a/torch/_dynamo/guards.py#L262
tensor_aliasing_guard_seen = False
def visit(root, ans):
nonlocal tensor_aliasing_guard_seen
for leaf_guard in root.get_leaf_guards():
if isinstance(leaf_guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING):
if not tensor_aliasing_guard_seen:
tensor_aliasing_guard_seen = True
else:
continue
append_guard_code(leaf_guard, ans)
for child in root.get_child_managers():
visit(child, ans)
guard_codes = []
root = guard_manager.root
# Add guards in RootGuardManager
visit(root, guard_codes)
# Add guards in epilogue lambda guards
if hasattr(root, "get_epilogue_lambda_guards"):
for lambda_guard in root.get_epilogue_lambda_guards():
append_guard_code(lambda_guard, guard_codes)
if guard_manager.closure_vars is None:
freevar_names = tuple()
freevar_values = []
else:
freevar_names = tuple(guard_manager.closure_vars.keys())
freevar_values = list(guard_manager.closure_vars.values())
self.guard = guard_codes
self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)}
code = cache.code
compiled_subgraphs = [
name for name in code.co_names if name.startswith("__compiled")]
assert len(compiled_subgraphs) <= 1
if compiled_subgraphs:
# deal with compiled_subgraph
self.compiled_subgraph = innermost_fn(module[compiled_subgraphs[0]])
# subgraph does not need decompile
self.compiled_subgraph_proxy = CodeProxy.decompile_with_name(
self.compiled_subgraph, compiled_subgraphs[0], skip_decompile=True)
else:
self.compiled_subgraph = None
self.compiled_subgraph_proxy = None
# deal with transformed_code
self.transformed_code = code
self.transformed_code_proxy = CodeProxy.decompile_with_name(
self.transformed_code, "transformed_code:")
resume_fns = [
name for name in code.co_names if name.startswith("__resume")]
self.referenced_global_functions = {}
for name in resume_fns:
self.referenced_global_functions[name] = DynamoOptimizationResult(
original_code=module[name].__code__,
function_name=name,
module=module)
def to_data(self):
return {
"guard": self.guard,
"transformed_code": str(
self.transformed_code_proxy),
"compiled_subgraph": str(
self.compiled_subgraph_proxy) if self.compiled_subgraph_proxy is not None else '"No compiled subgraph."',
"referenced_global_functions": {
name: fn.to_data() for name,
fn in self.referenced_global_functions.items()}}
@dataclass
class DynamoOptimizationResult:
function_name: str
module: dict
original_code: CodeType
source_code_proxy: CodeProxy
transformed_code_entries: List[CacheResult]
def __init__(self, original_code, function_name=None, module=None):
self.original_code = original_code
if function_name is None:
self.function_name = original_code.co_name
else:
self.function_name = function_name
self.module = module
caches = _debug_get_cache_entry_list(original_code)
self.transformed_code_entries = [
CacheResult(original_code, module, cache) for cache in caches]
self.source_code_proxy = CodeProxy.decompile_with_name(
self.original_code, self.function_name)
def to_data(self):
data = {
"function_name": self.function_name,
"source_code": str(
self.source_code_proxy),
"transformed_code_entries": [
entry.to_data() for entry in self.transformed_code_entries]}
return data
def to_src(self):
raw_code = self.source_code_proxy.raw_code
# prepare function signature, from `def toy_example(a, b)` to `def
# transformed_toy_example(a, b)`
signature = raw_code.splitlines()[0].replace(
"def ", "def transformed_", 1)
code = signature.strip()
# prepare args for guards, like `L = {"a": a, "b": b}`
code_obj = self.original_code
normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount
arg_names = code_obj.co_varnames[:normal_arg_count]
arg_dict = "__local_dict = {" + \
", ".join([f'"{name}": {name}' for name in arg_names]) + "}"
code += "\n" + " " * 4 + arg_dict
code += "\n" + " " * 4 + "__global_dict = globals()"
additional_code = ""
for entry in self.transformed_code_entries:
# prepare guards, like `def guard_0(L):\n return a > 0 and b >
# 0`
freevars = "".join([f"{name} = '''{value}'''\n" for name, value in entry.freevars.items() if name not in ["__builtins__"]])
if freevars:
freevars = "# Note: the following variables are used inside the guard function.\n" + freevars
guard_lines = [" " * 4 + "__guard_hit = True\n"]
for x in entry.guard:
guard_lines.append(" " * 4 + f"__guard_hit = __guard_hit and {x}\n")
guard_lines.append(" " * 4 + "return __guard_hit\n")
guard = "".join(guard_lines)
if entry.transformed_code_proxy.name.startswith("__transformed_code_"):
guard_func_name = entry.transformed_code_proxy.name.replace("__transformed_code_", "__guard_")
else:
guard_func_name = CodeProxy.consume_new_name("guard:")
additional_code += "\n" + freevars + f"def {guard_func_name}(L, G, **___kwargs_ignored):\n" + guard
if entry.compiled_subgraph_proxy is not None:
# prepare compiled subgraph, like `__compiled_fn_0`
subgraph_name = entry.compiled_subgraph_proxy.name
additional_code += "\n"
additional_code += f"# Note: please refer to the graph code in {subgraph_name}*.py.\n"
additional_code += f"# Captured Graph: Dynamo generated graph (debuggable when using eager backend).\n"
additional_code += f"# Joint graph: joint forward+backward graph from aot autograd.\n"
additional_code += f"# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).\n"
additional_code += f"# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).\n"
additional_code += f"# AFTER XXX: graph processed by inductor (not debuggable).\n"
additional_code += f"def {subgraph_name}(*args, **kwargs):\n pass\n"
# prepare transformed code, like `transformed_code_0`
additional_code += "\n" + \
remove_indentation(entry.transformed_code_proxy.raw_code) + "\n"
for name, func in entry.referenced_global_functions.items():
additional_code = func.to_src() + additional_code
code += "\n" + " " * 4 + \
f"if {guard_func_name}(__local_dict, __global_dict):\n" + " " * 8 + f"return {entry.transformed_code_proxy.name}({', '.join(arg_names)})"
additional_code += "\n" + "# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.\n" + \
remove_indentation(self.source_code_proxy.raw_code) + "\n"
code += "\n" + " " * 4 + "# Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.\n" + \
" " * 4 + f"return {self.source_code_proxy.name}({', '.join(arg_names)})"
return additional_code + code + \
f"\n\n#============ end of {self.function_name} ============#\n"
def remove_indentation(code: str):
lines = code.splitlines()
indent = len(lines[0]) - len(lines[0].lstrip())
return "".join([line[indent:] + "\n" for line in lines])
def append_guard_code(guard, ans):
for verbose_str in guard.verbose_code_parts():
verbose_str = verbose_str.strip()
ans.append(verbose_str)
from contextlib import contextmanager
@contextmanager
def lock_on_file(path_template):
lock_path = path_template + ".lock"
from filelock import FileLock
import os
lock = FileLock(lock_path)
try:
with lock:
yield
finally:
pass
def write_code_to_file_template(src, path_template):
with lock_on_file(path_template):
import os
count = 0
while True:
new_filepath = path_template % str(count)
if not os.path.exists(new_filepath):
with open(new_filepath, "w") as f:
f.write(src)
break
# might be a hash collision
existing_code = open(new_filepath).read()
if existing_code == src:
break
count += 1
return new_filepath
def get_current_compiled_fn_name():
import torch
from torch._dynamo.bytecode_transformation import _unique_id_counter
from copy import copy
# torch.compile already called the next, we should add minus 1 to get the
# correct name
current_count = next(copy(_unique_id_counter)) - 1
return "__compiled_fn_" + str(current_count)