File size: 13,682 Bytes
36cbb94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import torch
from torch._dynamo.eval_frame import innermost_fn
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
import inspect

import dis
from types import CodeType
from typing import List, Callable, Dict, Union, Set
from dataclasses import dataclass
import contextlib

class CodeProxy:
    """Registry of (decompiled) code snippets rendered as collapsible markdown.

    Every proxy owns a unique name. ``instances`` maps name -> proxy (or
    ``None`` for names reserved via :meth:`consume_new_name`), and
    ``used_instances`` records which proxies were stringified inside a
    :meth:`record` context.
    """
    instances: Dict[str, "CodeProxy"] = {}
    used_instances: Set[str] = set()

    @staticmethod
    def get_new_name(name: str) -> str:
        """Return a collision-free name.

        A trailing ``:`` marks *name* as a template: ``"guard:"`` yields
        ``guard_0``, ``guard_1``, ... (first index not in ``instances``).
        Names without the marker are returned unchanged.
        """
        candidate = name
        if candidate.endswith(":"):
            stem = name[:-1]
            index = 0
            candidate = f"{stem}_{index}"
            while candidate in CodeProxy.instances:
                index += 1
                candidate = f"{stem}_{index}"
        return candidate

    @staticmethod
    def consume_new_name(name: str) -> str:
        """Like :meth:`get_new_name`, but reserve the result so later calls
        cannot hand it out again."""
        new_name = CodeProxy.get_new_name(name)
        CodeProxy.instances[new_name] = None
        return new_name

    @staticmethod
    def decompile_with_name(code: CodeType, name: str, skip_decompile: bool = False) -> "CodeProxy":
        """Build a proxy for *code*, decompiling it unless told otherwise.

        Code objects whose name marks them as already-transformed by depyf
        (``transformed_code_*`` / ``__transformed_code_*``) are read back
        from their source file instead of being decompiled again.
        """
        from depyf.utils import decompile_ensure
        if hasattr(code, "__code__"):
            code = code.__code__
        if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"):
            # depyf wrote this code itself; its on-disk source is authoritative.
            with open(code.co_filename) as f:
                src = f.read()
            new_name = code.co_name
        else:
            new_name = CodeProxy.get_new_name(name)
            if not skip_decompile:
                src = decompile_ensure(code, new_name)
            else:
                src = ""
        self = CodeProxy(src)
        self.name = new_name
        self.code = f"""<details>
  <summary>{self.name}</summary>

  ```python
{self.raw_code}
  ```
</details>
"""
        CodeProxy.instances[self.name] = self
        return self

    def __init__(self, code: str):
        # Don't directly use this constructor. Use decompile_with_name instead.
        # Each non-empty line is kept, indented two spaces (markdown code block).
        self.raw_code = "".join(
            ["  " + line + "\n" for line in code.splitlines() if line.strip() != ""])

    def __str__(self):
        # Stringifying a proxy marks it as "used" for the active record() scope.
        CodeProxy.used_instances.add(self.name)
        return self.name

    # NOTE: @staticmethod must be the OUTERMOST decorator here. The original
    # order (contextmanager outside staticmethod) breaks on Python < 3.10,
    # where staticmethod objects are not callable.
    @staticmethod
    @contextlib.contextmanager
    def record():
        """Reset and yield the set of proxy names referenced via ``str()``."""
        CodeProxy.used_instances = set()
        yield CodeProxy.used_instances


@dataclass
class CacheResult:
    """One entry of Dynamo's compilation cache for a single code object.

    Extracts, from a cache entry: the guard conditions, the transformed
    bytecode, the compiled subgraph it dispatches to (if any), and any
    ``__resume_*`` functions the transformed code references.
    """
    # The Python code object Dynamo originally compiled.
    original_code: CodeType
    # Dynamo's rewritten bytecode for `original_code`.
    transformed_code: CodeType
    # Source-level guard expressions; all must hold for this entry to apply.
    guard: List[str]
    # Callable for the compiled subgraph (e.g. `__compiled_fn_0`), or None.
    compiled_subgraph: Callable
    # Proxy for the subgraph (never decompiled), or None when no subgraph.
    compiled_subgraph_proxy: CodeProxy
    # Proxy holding the decompiled transformed code.
    transformed_code_proxy: CodeProxy
    # Results for `__resume_*` functions referenced by the transformed code.
    referenced_global_functions: Dict[str, "DynamoOptimizationResult"]

    def __init__(self, original_code, module, cache):
        """Populate all fields from a Dynamo cache entry.

        :param original_code: code object the cache entry belongs to.
        :param module: dict-like namespace holding Dynamo-generated globals
            (compiled subgraphs and resume functions) — typically the frame's
            global namespace.
        :param cache: a Dynamo cache entry exposing ``code`` and either
            ``guard_manager`` or ``check_fn`` depending on the torch version.
        """
        self.original_code = original_code

        cpp_guard = False

        # starting from https://github.com/pytorch/pytorch/pull/138896 ,
        # pytorch uses `guard_manager` instead of `check_fn` to store the
        # guards
        attr_name = "guard_manager" if hasattr(cache, "guard_manager") else "check_fn"

        guard_manager = getattr(cache, attr_name)

        # Probe several torch versions for the C++ guard-manager class; if
        # none matches (or the probe itself fails), fall back to treating the
        # guard as a plain Python function.
        try:
            klass = getattr(torch._dynamo.guards, "GuardManagerWrapper", None) or \
                getattr(torch._dynamo.guards, "GuardManager", None) or \
                getattr(torch._C._dynamo.guards, "GuardManager", None)
            assert klass is not None
            cpp_guard = isinstance(guard_manager, klass)
        except Exception:
            pass

        if not cpp_guard:
            # for old version of pytorch,
            # `guard_manager` is a plain python function
            guard_codes = guard_manager.code_parts
            freevar_names = guard_manager.__code__.co_freevars
            freevar_values = [x.cell_contents for x in guard_manager.__closure__]
        else:
            # keep the logic synced with
            # https://github.com/pytorch/pytorch/blob/7b6b10417d8616ebd7a42b06528c5c2b2fded55a/torch/_dynamo/guards.py#L262
            # NO_TENSOR_ALIASING guards are duplicated across managers; keep
            # only the first occurrence.
            tensor_aliasing_guard_seen = False
            def visit(root, ans):
                # Depth-first walk of the guard-manager tree, collecting the
                # verbose code parts of every leaf guard into `ans`.
                nonlocal tensor_aliasing_guard_seen
                for leaf_guard in root.get_leaf_guards():
                    if isinstance(leaf_guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING):
                        if not tensor_aliasing_guard_seen:
                            tensor_aliasing_guard_seen = True
                        else:
                            continue
                    append_guard_code(leaf_guard, ans)
                for child in root.get_child_managers():
                    visit(child, ans)
            guard_codes = []
            root = guard_manager.root

            # Add guards in RootGuardManager
            visit(root, guard_codes)
            # Add guards in epilogue lambda guards
            if hasattr(root, "get_epilogue_lambda_guards"):
                for lambda_guard in root.get_epilogue_lambda_guards():
                    append_guard_code(lambda_guard, guard_codes)

            if guard_manager.closure_vars is None:
                freevar_names = tuple()
                freevar_values = []
            else:
                freevar_names = tuple(guard_manager.closure_vars.keys())
                freevar_values = list(guard_manager.closure_vars.values())

        self.guard = guard_codes
        # Free variables the guard expressions refer to (name -> value);
        # consumed later by DynamoOptimizationResult.to_src.
        self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)}
        code = cache.code

        # Dynamo emits at most one `__compiled_fn_*` global per entry.
        compiled_subgraphs = [
            name for name in code.co_names if name.startswith("__compiled")]
        assert len(compiled_subgraphs) <= 1

        if compiled_subgraphs:
            # deal with compiled_subgraph
            self.compiled_subgraph = innermost_fn(module[compiled_subgraphs[0]])
            # subgraph does not need decompile
            self.compiled_subgraph_proxy = CodeProxy.decompile_with_name(
                self.compiled_subgraph, compiled_subgraphs[0], skip_decompile=True)
        else:
            self.compiled_subgraph = None
            self.compiled_subgraph_proxy = None
        # deal with transformed_code
        self.transformed_code = code
        self.transformed_code_proxy = CodeProxy.decompile_with_name(
            self.transformed_code, "transformed_code:")
        # Recursively analyze every `__resume_*` function this entry can
        # jump to (graph breaks produce these continuation functions).
        resume_fns = [
            name for name in code.co_names if name.startswith("__resume")]
        self.referenced_global_functions = {}
        for name in resume_fns:
            self.referenced_global_functions[name] = DynamoOptimizationResult(
                original_code=module[name].__code__,
                function_name=name,
                module=module)

    def to_data(self):
        """Return a JSON-serializable summary of this cache entry."""
        return {
            "guard": self.guard,
            "transformed_code": str(
                self.transformed_code_proxy),
            "compiled_subgraph": str(
                self.compiled_subgraph_proxy) if self.compiled_subgraph_proxy is not None else '"No compiled subgraph."',
            "referenced_global_functions": {
                name: fn.to_data() for name,
                fn in self.referenced_global_functions.items()}}


@dataclass
class DynamoOptimizationResult:
    """Full Dynamo compilation picture for one function.

    Collects every cache entry (guard + transformed code) attached to the
    function's code object, and can render the whole thing back into
    readable Python source via :meth:`to_src`.
    """
    # Name used for the emitted `transformed_<name>` wrapper.
    function_name: str
    # Namespace holding Dynamo-generated globals (subgraphs, resume fns).
    module: dict
    # The code object whose cache entries are analyzed.
    original_code: CodeType
    # Decompiled original source.
    source_code_proxy: CodeProxy
    # One CacheResult per Dynamo cache entry, in cache order.
    transformed_code_entries: List[CacheResult]

    def __init__(self, original_code, function_name=None, module=None):
        """Analyze *original_code*'s Dynamo cache.

        :param original_code: code object to inspect.
        :param function_name: display name; defaults to ``co_name``.
        :param module: namespace for resolving Dynamo-generated globals.
        """
        self.original_code = original_code
        if function_name is None:
            self.function_name = original_code.co_name
        else:
            self.function_name = function_name
        self.module = module
        caches = _debug_get_cache_entry_list(original_code)
        self.transformed_code_entries = [
            CacheResult(original_code, module, cache) for cache in caches]
        self.source_code_proxy = CodeProxy.decompile_with_name(
            self.original_code, self.function_name)

    def to_data(self):
        """Return a JSON-serializable summary of the whole result tree."""
        data = {
            "function_name": self.function_name,
            "source_code": str(
                self.source_code_proxy),
            "transformed_code_entries": [
                entry.to_data() for entry in self.transformed_code_entries]}
        return data

    def to_src(self):
        """Render the result as runnable-looking Python source.

        Emits, in order: guard functions, compiled-subgraph stubs,
        transformed code, resume-function sources, the original source, and
        finally a ``transformed_<name>`` dispatcher that checks each guard
        and calls the matching transformed code (falling back to the
        original). Statement order matters: helpers must appear before the
        dispatcher that references them.
        """
        raw_code = self.source_code_proxy.raw_code

        # prepare function signature, from `def toy_example(a, b)` to `def
        # transformed_toy_example(a, b)`
        signature = raw_code.splitlines()[0].replace(
            "def ", "def transformed_", 1)
        code = signature.strip()

        # prepare args for guards, like `L = {"a": a, "b": b}`
        code_obj = self.original_code
        normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount
        arg_names = code_obj.co_varnames[:normal_arg_count]
        arg_dict = "__local_dict = {" + \
            ", ".join([f'"{name}": {name}' for name in arg_names]) + "}"
        code += "\n" + " " * 4 + arg_dict
        code += "\n" + " " * 4 + "__global_dict = globals()"

        additional_code = ""

        for entry in self.transformed_code_entries:

            # prepare guards, like `def guard_0(L):\n    return a > 0 and b >
            # 0`
            # Guard freevars are inlined as string literals so the emitted
            # source is self-contained. NOTE(review): non-string values are
            # stringified by the f-string; presumably acceptable for display.
            freevars = "".join([f"{name} = '''{value}'''\n" for name, value in entry.freevars.items() if name not in ["__builtins__"]])
            if freevars:
                freevars = "# Note: the following variables are used inside the guard function.\n" + freevars
            guard_lines = [" " * 4 + "__guard_hit = True\n"]
            for x in entry.guard:
                guard_lines.append(" " * 4 + f"__guard_hit = __guard_hit and {x}\n")
            guard_lines.append(" " * 4 + "return __guard_hit\n")
            guard = "".join(guard_lines)
            # Reuse the numeric suffix of the transformed code for the guard
            # name when available, so guard_N pairs with transformed_code_N.
            if entry.transformed_code_proxy.name.startswith("__transformed_code_"):
                guard_func_name = entry.transformed_code_proxy.name.replace("__transformed_code_", "__guard_")
            else:
                guard_func_name = CodeProxy.consume_new_name("guard:")
            additional_code += "\n" + freevars + f"def {guard_func_name}(L, G, **___kwargs_ignored):\n" + guard

            if entry.compiled_subgraph_proxy is not None:
                # prepare compiled subgraph, like `__compiled_fn_0`
                subgraph_name = entry.compiled_subgraph_proxy.name
                additional_code += "\n"
                additional_code += f"# Note: please refer to the graph code in {subgraph_name}*.py.\n"
                additional_code += f"# Captured Graph: Dynamo generated graph (debuggable when using eager backend).\n"
                additional_code += f"# Joint graph: joint forward+backward graph from aot autograd.\n"
                additional_code += f"# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).\n"
                additional_code += f"# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).\n"
                additional_code += f"# AFTER XXX: graph processed by inductor (not debuggable).\n"
                additional_code += f"def {subgraph_name}(*args, **kwargs):\n    pass\n"

            # prepare transformed code, like `transformed_code_0`
            additional_code += "\n" + \
                remove_indentation(entry.transformed_code_proxy.raw_code) + "\n"

            # Resume functions are prepended so they are defined before the
            # transformed code that calls them.
            for name, func in entry.referenced_global_functions.items():
                additional_code = func.to_src() + additional_code

            code += "\n" + " " * 4 + \
                f"if {guard_func_name}(__local_dict, __global_dict):\n" + " " * 8 + f"return {entry.transformed_code_proxy.name}({', '.join(arg_names)})"

        additional_code += "\n" + "# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.\n" + \
            remove_indentation(self.source_code_proxy.raw_code) + "\n"

        code += "\n" + " " * 4 + "# Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.\n" + \
            " " * 4 + f"return {self.source_code_proxy.name}({', '.join(arg_names)})"
        return additional_code + code + \
            f"\n\n#============ end of {self.function_name} ============#\n"


def remove_indentation(code: str) -> str:
    """Strip the first line's leading indentation from every line.

    The indent width is measured on the first line only; that many leading
    characters are removed from each line and a trailing newline is appended
    to every line. Returns ``""`` for empty input (the original raised
    ``IndexError`` on ``lines[0]``).
    """
    lines = code.splitlines()
    if not lines:
        return ""
    indent = len(lines[0]) - len(lines[0].lstrip())
    return "".join(line[indent:] + "\n" for line in lines)

def append_guard_code(guard, ans):
    """Append each of *guard*'s verbose code parts, whitespace-stripped,
    to the list *ans* (mutated in place)."""
    ans.extend(part.strip() for part in guard.verbose_code_parts())

from contextlib import contextmanager

@contextmanager
def lock_on_file(path_template):
    """Hold an inter-process file lock for the duration of the ``with`` body.

    The lock file is ``<path_template>.lock``, so all writers using the same
    template serialize against each other. Requires the third-party
    ``filelock`` package (imported lazily).

    Removed from the original: an unused ``import os`` and a no-op
    ``try/finally: pass`` around the ``with`` — ``FileLock`` already
    releases on exit.
    """
    from filelock import FileLock
    lock = FileLock(path_template + ".lock")
    with lock:
        yield


def write_code_to_file_template(src, path_template):
    """Write *src* to the first free numbered slot of *path_template*.

    *path_template* must contain one ``%s`` placeholder, filled with 0, 1,
    2, ... If an existing slot already holds exactly *src* (same content
    despite a name collision) that path is reused instead of writing a
    duplicate. The whole scan-and-write runs under ``lock_on_file`` so
    concurrent processes cannot race on the same template.

    :return: the path the code lives at.
    """
    import os
    with lock_on_file(path_template):
        count = 0
        while True:
            new_filepath = path_template % str(count)
            if not os.path.exists(new_filepath):
                with open(new_filepath, "w") as f:
                    f.write(src)
                return new_filepath
            # might be a hash collision: reuse the slot only when the
            # existing contents match exactly. (Original leaked this file
            # handle via a bare open().read().)
            with open(new_filepath) as f:
                if f.read() == src:
                    return new_filepath
            count += 1


def get_current_compiled_fn_name():
    """Return the name of the compiled function Dynamo is currently emitting,
    e.g. ``"__compiled_fn_3"``.

    Peeks Dynamo's unique-id counter without consuming it by advancing a
    copy; torch.compile has already taken the next id, hence the ``- 1``.
    """
    import torch
    from copy import copy
    from torch._dynamo.bytecode_transformation import _unique_id_counter
    counter_snapshot = copy(_unique_id_counter)
    current_index = next(counter_snapshot) - 1
    return f"__compiled_fn_{current_index}"