koichi12 commited on
Commit
36cbb94
·
verified ·
1 Parent(s): 0ae5c3e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/depyf/VERSION.txt +1 -0
  2. .venv/lib/python3.11/site-packages/depyf/__init__.py +24 -0
  3. .venv/lib/python3.11/site-packages/depyf/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/depyf/__pycache__/code_transform.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/depyf/__pycache__/decompiler.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/depyf/__pycache__/optimization.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/depyf/__pycache__/utils.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/depyf/code_transform.py +475 -0
  9. .venv/lib/python3.11/site-packages/depyf/decompiler.py +1312 -0
  10. .venv/lib/python3.11/site-packages/depyf/explain/__init__.py +17 -0
  11. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/__init__.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enable_debugging.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enhance_logging.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/global_variables.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched___call__.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched__exec_with_source.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_boxed_run.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_lazy_format_graph_code.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_load_by_key_path.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/depyf/explain/__pycache__/utils.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/depyf/explain/enable_debugging.py +251 -0
  22. .venv/lib/python3.11/site-packages/depyf/explain/enhance_logging.py +94 -0
  23. .venv/lib/python3.11/site-packages/depyf/explain/global_variables.py +14 -0
  24. .venv/lib/python3.11/site-packages/depyf/explain/patched___call__.py +9 -0
  25. .venv/lib/python3.11/site-packages/depyf/explain/patched__exec_with_source.py +20 -0
  26. .venv/lib/python3.11/site-packages/depyf/explain/patched_boxed_run.py +2 -0
  27. .venv/lib/python3.11/site-packages/depyf/explain/patched_lazy_format_graph_code.py +78 -0
  28. .venv/lib/python3.11/site-packages/depyf/explain/patched_load_by_key_path.py +21 -0
  29. .venv/lib/python3.11/site-packages/depyf/explain/utils.py +338 -0
  30. .venv/lib/python3.11/site-packages/depyf/optimization.py +74 -0
  31. .venv/lib/python3.11/site-packages/depyf/utils.py +90 -0
  32. .venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/INSTALLER +1 -0
  33. .venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/WHEEL +4 -0
  34. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/INSTALLER +1 -0
  35. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/LICENSE +20 -0
  36. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/METADATA +27 -0
  37. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/RECORD +14 -0
  38. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/WHEEL +5 -0
  39. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/entry_points.txt +3 -0
  40. .venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/top_level.txt +1 -0
  41. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__init__.py +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/__init__.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/layers.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/models.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/request.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/utils.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/worker_manager.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/vllm/adapter_commons/layers.py +16 -0
  49. .venv/lib/python3.11/site-packages/vllm/adapter_commons/models.py +105 -0
  50. .venv/lib/python3.11/site-packages/vllm/adapter_commons/request.py +25 -0
.venv/lib/python3.11/site-packages/depyf/VERSION.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.18.0
.venv/lib/python3.11/site-packages/depyf/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from types import CodeType
import warnings

from .decompiler import Decompiler, decompile

try:
    import torch
    torch_version = torch.__version__
    # NOTE(review): these are lexicographic string comparisons; they work for
    # the current "2.x" / "...devYYYYMMDD" formats but would misorder e.g.
    # "2.10" vs "2.2" -- confirm before relying on them long-term.
    valid = ("dev" not in torch_version and torch_version >= "2.2") or (
        "dev" in torch_version and torch_version.split("dev")[-1] >= "20231020")
    if not valid:
        warnings.warn(
            ("Please use the nightly version of PyTorch to enable bytecode hooks.\n"
             "PyTorch nightly can be installed by: `conda install pytorch-nightly::pytorch torchvision torchaudio -c pytorch-nightly`"))

    from depyf.explain.enhance_logging import install, uninstall
    from depyf.explain.enable_debugging import prepare_debug, debug
except ImportError as e:
    # torch (or a torch-dependent submodule) is unavailable; the core
    # decompiler API imported above still works without it.
    pass

import os

# Fix: read VERSION.txt inside a context manager so the file handle is
# closed deterministically instead of leaking until garbage collection.
with open(f"{os.path.dirname(__file__)}/VERSION.txt") as _version_file:
    __version__ = _version_file.read().strip()
.venv/lib/python3.11/site-packages/depyf/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.51 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/__pycache__/code_transform.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/__pycache__/decompiler.cpython-311.pyc ADDED
Binary file (83.7 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/__pycache__/optimization.cpython-311.pyc ADDED
Binary file (4.57 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/code_transform.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dis
2
+ from typing import List, Tuple, Union, Optional, Callable, Any, Dict, Set
3
+ from types import CodeType
4
+ import ast
5
+ import astor
6
+ from collections import defaultdict
7
+ import dataclasses
8
+ import sys
9
+ import hashlib
10
+
11
+ py311 = sys.version_info >= (3, 11)
12
+ all_jump_opcode_set = set(dis.hasjabs) | set(dis.hasjrel)
13
+
14
+
15
@dataclasses.dataclass
class Instruction:
    """A mutable version of dis.Instruction"""

    opcode: int
    opname: str
    arg: Optional[int]
    argval: Any
    argrepr: str
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False

    def __hash__(self):
        # Identity-based hash/eq: two instructions with identical fields are
        # still distinct objects when used as dict keys or set members.
        return id(self)

    def __eq__(self, other):
        return id(self) == id(other)

    def short_inst_repr(self):
        # Compact repr for debugging/log output.
        return f"Instruction(opname={self.opname}, offset={self.offset})"

    def is_jump(self):
        # True for any absolute or relative jump opcode (see module-level
        # `all_jump_opcode_set` built from dis.hasjabs / dis.hasjrel).
        return self.opcode in all_jump_opcode_set

    def get_jump_target(self: "Instruction"):
        """Return the absolute bytecode offset this jump targets.

        Prefers the "to NN" text in `argrepr` when present; otherwise falls
        back to `argval` (absolute jumps) or `offset + argval` (relative
        jumps, pre-3.11).  Raises ValueError for non-jump instructions.
        """
        if self.is_jump() and "to " in self.argrepr:
            return int(self.argrepr.replace("to ", "").strip())
        # seems like a bug, "FOR_ITER" is in `dis.hasjrel`, but its `argval` is
        # an absolute offset
        if self.opcode in dis.hasjabs:
            return self.argval
        elif self.opcode in dis.hasjrel:
            # on 3.11+ dis already resolves relative jumps to absolute argval
            return self.offset + self.argval if not py311 else self.argval
        else:
            raise ValueError(
                f"Instruction {self.opname} does not have jump target")
52
+
53
+
54
def convert_instruction(i: dis.Instruction) -> Instruction:
    """Copy an (immutable) dis.Instruction into the mutable Instruction type."""
    fields = (
        i.opcode,
        i.opname,
        i.arg,
        i.argval,
        i.argrepr,
        i.offset,
        i.starts_line,
        i.is_jump_target,
    )
    return Instruction(*fields)
65
+
66
+
67
def nop_instruction(inst: "Instruction") -> "Instruction":
    """Rewrite `inst` in place into a NOP and return the same object.

    `offset` and `starts_line` are intentionally left untouched so that
    jump-target bookkeeping and line-number propagation still work after
    the rewrite.

    Fix: the original body contained two bare expression statements
    (`inst.offset` and `inst.starts_line`) that read an attribute and
    discarded the result; they had no effect and have been removed.
    """
    inst.opname = "NOP"
    inst.opcode = dis.opmap["NOP"]
    inst.arg = 0
    inst.argval = 0
    inst.argrepr = ""
    inst.is_jump_target = False
    return inst
78
+
79
+
80
def propagate_line_nums(instructions: "List[Instruction]"):
    """Fill in `starts_line` for every instruction.

    `dis` only records a line number on the first instruction of each source
    line; copy it forward so that removing instructions never loses line
    information.
    """
    current_line = None
    for inst in instructions:
        if inst.starts_line:
            current_line = inst.starts_line
        else:
            inst.starts_line = current_line
93
+
94
+
95
# ======= begin code borrowed from pytorch/torch/_dynamo/bytecode_transformation.py ===========
@dataclasses.dataclass
class ExceptionTableEntry:
    """One decoded row of a CPython 3.11+ `co_exceptiontable`.

    Offsets are byte offsets into the bytecode (`parse_exception_table`
    already multiplies the encoded code-unit counts by 2).
    """
    start: int   # first covered instruction offset (inclusive)
    end: int     # last covered instruction offset (inclusive)
    target: int  # handler offset jumped to when an exception is raised
    depth: int   # stack depth to restore before entering the handler
    lasti: bool  # whether the offset of the raising instruction is pushed
+
104
+ def decode_exception_table_varint(bytes_iter) -> int:
105
+ """
106
+ Inverse of `encode_exception_table_varint`.
107
+ """
108
+ b = next(bytes_iter)
109
+ val = b & 63
110
+ while b & 64:
111
+ val <<= 6
112
+ b = next(bytes_iter)
113
+ val |= b & 63
114
+ return val
115
+
116
def check_exception_table(tab: "List[ExceptionTableEntry]") -> None:
    """
    Verifies that a list of ExceptionTableEntries forms a well-formed jump
    table: entries are non-empty, sorted by offset, and do not overlap.
    Raises AssertionError otherwise.
    """
    for prev, cur in zip(tab, tab[1:]):
        assert prev.start <= prev.end
        assert prev.end < cur.start
        assert cur.start <= cur.end
127
+
128
def parse_exception_table(exntab) -> "List[ExceptionTableEntry]":
    """
    Parse the exception table according to
    https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt

    Encoded offsets count code units, so each decoded value is doubled to
    obtain byte offsets.  Parsing stops when the byte stream is exhausted
    (StopIteration from the varint decoder); the accumulated table is then
    validated and returned.
    """
    entries = []
    stream = iter(exntab)
    try:
        while True:
            start = decode_exception_table_varint(stream) * 2
            length = decode_exception_table_varint(stream) * 2
            end = start + length - 2
            target = decode_exception_table_varint(stream) * 2
            depth_lasti = decode_exception_table_varint(stream)
            entries.append(ExceptionTableEntry(
                start, end, target, depth_lasti >> 1, bool(depth_lasti & 1)))
    except StopIteration:
        check_exception_table(entries)
        return entries
# ======= end code borrowed from pytorch/torch/_dynamo/bytecode_transformation.py ===========
149
+
150
def simplify_finally_statement(instructions: List[Instruction]):
    """Simplify finally statement.

    3.10 lays out a finally statement as:
        SETUP_FINALLY
        body
        POP_BLOCK
        finally code
        Exception code
        RERAISE
    The duplicated exception-path copy (everything at/after the
    SETUP_FINALLY jump target, up to and including the first RERAISE)
    is NOP-ed out, keeping only the normal-path copy of the finally code.
    """
    for i, inst in enumerate(instructions):
        if inst.opname == "SETUP_FINALLY":
            finally_target = inst.get_jump_target()
            # indices of RERAISE instructions located in the exception path
            reraise_idx = [j for j, _inst in enumerate(
                instructions) if _inst.offset >= finally_target and _inst.opname == "RERAISE"]
            if reraise_idx:
                reraise_index = reraise_idx[0]
                # NOP everything from the finally target through that RERAISE
                for j, _inst in enumerate(instructions):
                    if _inst.offset >= finally_target and j <= reraise_index:
                        nop_instruction(_inst)
170
+
171
+
172
def nop_unreachable_bytecode(code,
                             instructions: List[dis.Instruction]) -> List[dis.Instruction]:
    """Mark unreachable bytecode as NOP (in place, via `nop_instruction`)."""
    jumps = set(dis.hasjabs) | set(dis.hasjrel)

    # map from handler offset -> exception table entry (3.11+ only)
    exception_targets = {}
    if py311:
        tab = parse_exception_table(code.co_exceptiontable)
        exception_targets = {entry.target: entry for entry in tab}

    # difference between `i in deadcode_positions` and `reachable[i] == False`:
    # `i in deadcode_positions` means that the instruction is not reachable, definitely a NOP
    # `reachable[i] == False` means that the instruction is not reachable currently, but it might be reachable later when we iterate through the instructions
    reachable = [False for x in instructions]
    deadcode_positions = set()
    reachable[0] = True
    # each instruction marks the instruction after it
    for i, inst in enumerate(instructions):
        if inst.is_jump_target or inst.offset in exception_targets:
            # the instruction is the target of a jump
            reachable[i] = True
        # the last instruction does not need to mark any following instructions
        if i == len(instructions) - 1:
            break
        # this instruction is not reachable, nothing to do
        if not reachable[i]:
            continue
        # this instruction is reachable
        # the following instruction is reachable if it is sequential op or
        # conditional jump
        if inst.opname in ["RETURN_VALUE", "BREAK_LOOP"]:
            # the instruction after the return is unreachable
            pass
        elif inst.opcode in jumps:
            if inst.opcode in dis.hasjrel and inst.get_jump_target() == inst.offset:
                # this is a jump to itself, it is regarded as a NOP, per the documentation at
                # https://devguide.python.org/internals/interpreter/#jumps
                reachable[i] = False
                reachable[i + 1] = True
                continue
            if "IF" in inst.opname or "FOR_ITER" in inst.opname or "SETUP_LOOP" in inst.opname:
                # the fallback block is always reachable for conditional jumps
                reachable[i + 1] = True
            elif inst.opname in ["SETUP_FINALLY", "SETUP_WITH", "BEFORE_WITH"]:
                # the with/finally block is always reachable
                reachable[i + 1] = True
            else:
                # this is a direct jump, the target is reachable
                # we further check if any outside instructions jump into in-between instructions
                # if not, we can mark this instruction as unreachable, too
                # later, in-between instructions will be marked as unreachable (NOP)
                # and the interpreter will slide through all the NOP directly
                # to the target
                jump_forwards = [j for j, instruct in enumerate(
                    instructions) if instruct.offset >= inst.get_jump_target()]
                if len(jump_forwards):
                    j = jump_forwards[0]
                    if j > i:
                        smallest_jump_in = j
                        has_jump_in = False

                        # in python 3.11 exception table
                        # exception target indicates a jump target from many instructions
                        # and therefore it is treated as a jump-in
                        # NOTE(review): `ii` below is relative to `i` (it
                        # enumerates a slice), while `smallest_jump_in` holds
                        # absolute indices -- looks like it should be `i + ii`;
                        # confirm intended behavior before changing.
                        for ii, inst_ii in enumerate(instructions[i: j]):
                            if inst_ii.offset in exception_targets:
                                has_jump_in = True
                                smallest_jump_in = min(
                                    smallest_jump_in, ii)

                        for ii, inst_ii in enumerate(instructions):
                            try:
                                jump_location = inst_ii.get_jump_target()
                                if (ii < i or ii > j) and (jump_location >= inst.offset and jump_location < instructions[j].offset):
                                    has_jump_in = True
                                    smallest_jump_in = min(
                                        smallest_jump_in, ii)
                            except Exception:
                                # non-jump instructions raise in get_jump_target
                                pass
                        if not has_jump_in:
                            reachable[i] = False
                            for _ in range(i, smallest_jump_in):
                                deadcode_positions.add(_)
        else:
            reachable[i + 1] = True

    for i in deadcode_positions:
        reachable[i] = False

    # mark unreachable instructions as NOP
    for inst, flag in zip(instructions, reachable):
        if not flag:
            nop_instruction(inst)
265
+
266
+
267
def add_indentation(code: str, indentation: int = 4) -> str:
    """Prefix every line of `code` with `indentation` spaces.

    The result always ends with a newline (one per line of input).
    """
    prefix = " " * indentation
    return "".join(f"{prefix}{line}\n" for line in code.splitlines())
274
+
275
+
276
def remove_indentation(code: str, indentation: int = 4) -> str:
    """Drop the first `indentation` characters from every line of `code`.

    Assumes each line actually carries that much leading whitespace; shorter
    lines simply become empty.
    """
    stripped = [line[indentation:] for line in code.splitlines()]
    return "".join(line + "\n" for line in stripped)
279
+
280
+
281
class RemoveAssignmentTransformer(ast.NodeTransformer):
    """First pass of temp-variable elimination (driven by `remove_some_temp`).

    `temp_occurrences[name]` protocol (as set up by `remove_some_temp`):
      - 1 entry  [Name]: the temp is only written; the assignment is turned
        into a bare expression statement.
      - 3 entries [Name, Name, can_merge(bool)]: one write and one read; the
        assigned value is appended here (making 4 entries) and, when
        `can_merge` is True, the assignment statement is deleted so the
        second pass can inline the value at the read site.
    """

    def __init__(self,
                 temp_name: str,
                 temp_occurrences: Dict[str,
                                        List[ast.Name]]):
        # optimize one temp_name at a time
        self.temp_name = temp_name
        self.temp_occurrences = temp_occurrences

    def visit_Assign(self, node):
        # single assignment like `temp = xxx`
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            name = node.targets[0].id
            # the assignment is like `temp = xxx`
            if name == self.temp_name:
                if len(self.temp_occurrences[name]) == 1:
                    return ast.Expr(value=node.value)
                elif len(self.temp_occurrences[name]) == 3 and isinstance(self.temp_occurrences[name][-1], bool):
                    # we save the `xxx` here
                    self.temp_occurrences[name].append(node.value)
                    if self.temp_occurrences[name][-2]:
                        # can_merge: drop the assignment statement entirely
                        return None
        return node
304
+
305
+
306
class RemoveAssignment2Transformer(ast.NodeTransformer):
    """Second pass of temp-variable elimination: inline the saved value.

    Only fires when `temp_occurrences[name]` holds 4 entries
    [Name, Name, can_merge(bool), value] (shape produced by
    `RemoveAssignmentTransformer`) and `can_merge` is True; the read of the
    temp name is then replaced by the saved value expression.
    """

    def __init__(self,
                 temp_name: str,
                 temp_occurrences: Dict[str,
                                        List[ast.Name]]):
        # optimize one temp_name at a time
        self.temp_name = temp_name
        self.temp_occurrences = temp_occurrences

    def visit_Name(self, node):
        name = node.id
        if name == self.temp_name and len(self.temp_occurrences[name]) == 4 and isinstance(
                self.temp_occurrences[name][-2], bool):
            if self.temp_occurrences[name][-2]:
                return self.temp_occurrences[name][-1]
        return node
322
+
323
+
324
def get_parents(node):
    """Return `node` followed by all its ancestors (via `.parent`), innermost first."""
    chain = []
    current = node
    while current:
        chain.append(current)
        current = getattr(current, "parent", None)
    return chain
331
+
332
+
333
def set_parents(node, parent=None):
    """Recursively annotate every descendant of `node` with a `.parent` attribute.

    Fix: the original assigned `child.parent = parent` -- the value passed in
    by the *caller* (i.e. the grandparent) -- so the root's direct children
    ended up with `parent = None`.  Each child's parent is `node` itself.
    The `parent` parameter is kept (now unused) for backward compatibility.
    """
    for child in ast.iter_child_nodes(node):
        child.parent = node
        set_parents(child, node)
338
+
339
+
340
def lowest_common_parent(node1, node2):
    """Get the lowest common parent for two nodes.

    Returns a 3-tuple: the deepest node shared by both ancestor chains (or
    None if they share nothing), plus the ancestor of each node at the first
    point where the two chains diverge.
    """
    # Build root-first ancestor chains for both nodes.
    chain1 = get_parents(node1)[::-1]
    chain2 = get_parents(node2)[::-1]

    last_common = None
    p1 = p2 = None
    for p1, p2 in zip(chain1, chain2):
        if p1 is not p2:
            break
        last_common = p1
    return last_common, p1, p2
356
+
357
+
358
def remove_some_temp(
        source_code: str,
        temp_prefix: str,
        indentation: int = 4) -> str:
    """Remove/inline temporaries whose names start with `temp_prefix`.

    Only temps with exactly two occurrences (one write, one read) are
    candidates for inlining; single-occurrence temps have their assignment
    turned into a bare expression.  Returns the regenerated source (via
    astor) indented with `indentation` spaces.
    """
    tree = ast.parse(source_code)
    set_parents(tree)

    # collect every Name node per temp identifier
    temp_occurrences = defaultdict(list)
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and node.id.startswith(temp_prefix):
            temp_occurrences[node.id].append(node)

    for key in temp_occurrences:
        if len(temp_occurrences[key]) == 2:
            node1 = temp_occurrences[key][0]
            node2 = temp_occurrences[key][1]
            parent, parent1, parent2 = lowest_common_parent(node1, node2)
            assignment_node = node1 if isinstance(
                node1.parent, ast.Assign) else node2
            assignment_parent = parent1 if isinstance(
                node1.parent, ast.Assign) else parent2
            indentation_nodes = (
                ast.FunctionDef,
                ast.AsyncFunctionDef,
                ast.For,
                ast.AsyncFor,
                ast.While,
                ast.If,
                ast.Try,
                ast.With,
                ast.AsyncWith,
                ast.ClassDef)
            # we cannot remove the assignment if the assignment `temp=xxx` is
            # in an indentation block while the usage of `temp` is not
            can_merge = not isinstance(assignment_parent, indentation_nodes)
            temp_occurrences[key].append(can_merge)
        # two-pass rewrite: drop the assignment, then inline the saved value
        tree = RemoveAssignmentTransformer(key, temp_occurrences).visit(tree)
        tree = RemoveAssignment2Transformer(key, temp_occurrences).visit(tree)

    reconstructed_code = astor.to_source(tree, indent_with=" " * indentation)
    return reconstructed_code
399
+
400
+
401
class IdentifierReplacer(ast.NodeTransformer):
    """Rename every function definition to 'PLACEHOLDER'.

    Presumably used to compare code while ignoring function names --
    TODO confirm against callers.  The commented-out visitors show other
    node kinds that could be anonymized the same way but are disabled.
    """

    # def visit_Name(self, node):
    #     return ast.copy_location(ast.Name(id='PLACEHOLDER', ctx=node.ctx), node)

    def visit_FunctionDef(self, node):
        node.name = 'PLACEHOLDER'
        return self.generic_visit(node)

    # def visit_AsyncFunctionDef(self, node):
    #     node.name = 'PLACEHOLDER'
    #     return self.generic_visit(node)

    # def visit_ClassDef(self, node):
    #     node.name = 'PLACEHOLDER'
    #     return self.generic_visit(node)

    # def visit_Attribute(self, node):
    #     node.attr = 'PLACEHOLDER'
    #     return self.generic_visit(node)
421
+
422
+
423
def fix_irregular_code(
        old_bytecode: CodeType,
        src_code: str,
        add_local_variables: Optional[List[str]] = None,
        add_cellvars: Optional[List[str]] = None,
) -> str:
    """Patch decompiled source so recompiling it reproduces the shape of
    `old_bytecode`: same free variables, and at least as many local
    variables and cellvars.  Recurses at most once more per missing set.

    NOTE(review): the indentation inside the generated-code string literals
    below was reconstructed (the extracted source had whitespace collapsed);
    confirm widths against the upstream depyf repository.
    """
    # the decompiled source starts with `def <name>(...`; grab the name
    function_name = src_code.split("(")[0].split()[-1]
    new_code = src_code
    if add_local_variables is not None or add_cellvars is not None:
        lines = src_code.splitlines()
        header = lines[0]
        body = lines[1:]
        headers = [header]
        if add_local_variables:
            added_line = "; ".join(f"{x} = None" for x in add_local_variables)
            added_line = "    " + added_line + " # this line helps Python to generate bytecode with at least the same number of local variables as the original function\n"
            headers.append(added_line)
        if add_cellvars:
            # a nested function returning the names forces them to be cellvars
            added_line = "return " + ", ".join(x for x in add_cellvars)
            added_line = (
                "    def __helper_for_cellvars():\n"
                "        # this function helps Python to generate bytecode with at least the same number of cellvars as the original function\n"
            ) + "        " + added_line
            headers.append(added_line)
        new_code = "".join([x + "\n" for x in headers + body])

    freevars = old_bytecode.co_freevars
    if freevars:
        # wrap in an outer function so the freevars become true closures
        tmp_code = (
            "def __helper_outer_function():\n"
            "    # this is a helper function to help compilers generate bytecode to read capture variables from closures, rather than reading values from global scope. The value of these variables does not matter, and will be determined in runtime.\n"
        )
        for freevar in freevars:
            tmp_code += f"    {freevar} = None\n"
        tmp_code += add_indentation(new_code, 4)
        new_code = tmp_code

    # make sure the new bytecode has at least the same number of local variables as the original bytecode
    # this seems to fix the test failure in https://github.com/thuml/depyf/actions/runs/7004325219/job/19051829613 , and might be related with the discussion in https://github.com/pytorch/pytorch/pull/111883
    compiled_code = compile(new_code, "noname", "exec")
    from .utils import collect_all_code_objects
    code_objects = collect_all_code_objects(compiled_code)
    target_code = [x for x in code_objects if x.co_name == function_name][0]

    missing_local_variables = set(old_bytecode.co_varnames) - set(target_code.co_varnames)
    missing_cellvars = set(old_bytecode.co_cellvars) - set(target_code.co_cellvars)

    if missing_local_variables or missing_cellvars:
        # recompile once more with the still-missing names injected
        return fix_irregular_code(
            old_bytecode, src_code,
            add_local_variables=sorted(list(missing_local_variables)),
            add_cellvars=sorted(list(missing_cellvars)))
    return new_code
.venv/lib/python3.11/site-packages/depyf/decompiler.py ADDED
@@ -0,0 +1,1312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple program to transform bytecode into more readable source code."""
2
+
3
+ import sys
4
+ import os
5
+ import dis
6
+ from types import CodeType
7
+ from typing import List, Tuple, Dict, Union, Callable, Optional
8
+ import dataclasses
9
+ import inspect
10
+ import functools
11
+ from collections import defaultdict
12
+ import contextlib
13
+
14
+ from .code_transform import (
15
+ nop_unreachable_bytecode,
16
+ nop_instruction,
17
+ add_indentation,
18
+ remove_indentation,
19
+ remove_some_temp,
20
+ propagate_line_nums,
21
+ convert_instruction,
22
+ simplify_finally_statement,
23
+ Instruction,
24
+ )
25
+ from .utils import (
26
+ get_function_signature,
27
+ )
28
+
29
+
30
class DecompilationError(Exception):
    """Custom exception class for decompilation failures."""

    def __init__(self, message=""):
        self.message = message
        super().__init__(message)

    def __str__(self):
        return f'DecompilationError: {self.message}'
39
+
40
+
41
@dataclasses.dataclass
class DecompilerState:
    """State of decompiler, keep track of the evaluation stack, as well as the decompiled source code."""
    source_code: str            # source text emitted so far
    stack: list                 # simulated evaluation stack (source fragments)
    inside_loop: bool = False   # whether we are currently inside a loop body
    loop_start_index: int = -1  # inclusive
    loop_end_index: int = -1    # exclusive
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class Decompiler:
53
+ """A decompiler for a code object."""
54
+ code: CodeType
55
+ temp_count: int = 0
56
+ temp_prefix: str = "__temp_"
57
+ state: DecompilerState = dataclasses.field(
58
+ default_factory=lambda: DecompilerState(
59
+ source_code="", stack=[]))
60
+ indentation: int = 4
61
+
62
+ @contextlib.contextmanager
63
+ def new_state(self, stack, inside_loop=False, loop_start_index=-1, loop_end_index=-1):
64
+ """Create a new state for decompiler."""
65
+ state = DecompilerState(source_code="", stack=stack, inside_loop=inside_loop, loop_start_index=loop_start_index, loop_end_index=loop_end_index)
66
+ old_state = self.state
67
+ if old_state.inside_loop and not state.inside_loop:
68
+ # inherit the loop state from the old state
69
+ state.inside_loop = old_state.inside_loop
70
+ state.loop_start_index = old_state.loop_start_index
71
+ state.loop_end_index = old_state.loop_end_index
72
+ self.state = state
73
+ yield
74
+ self.state = old_state
75
+
76
+
77
+ # ==================== Unsupported Instructions =============================
78
+ def unimplemented_instruction(self, inst: Instruction):
79
+ raise NotImplementedError(f"Unsupported instruction: {inst.opname}")
80
+
81
+ GET_YIELD_FROM_ITER = unimplemented_instruction
82
+
83
+ # we don't support try-except/try-finally
84
+ POP_EXCEPT = WITH_EXCEPT_START = JUMP_IF_NOT_EXC_MATCH = CHECK_EG_MATCH = PUSH_EXC_INFO = PREP_RERAISE_STAR = WITH_CLEANUP_FINISH = CALL_FINALLY = POP_FINALLY = WITH_CLEANUP_START = SETUP_EXCEPT = CHECK_EXC_MATCH = CLEANUP_THROW = unimplemented_instruction
85
+
86
+ # we don't support async/await
87
+ GET_AWAITABLE = GET_AITER = GET_ANEXT = END_ASYNC_FOR = BEFORE_ASYNC_WITH = SETUP_ASYNC_WITH = SEND = ASYNC_GEN_WRAP = unimplemented_instruction
88
+
89
+ CACHE = unimplemented_instruction
90
+
91
+ # we don't know these instructions
92
+ PRINT_EXPR = COPY_DICT_WITHOUT_KEYS = unimplemented_instruction
93
+
94
+ # we only support bytecode for functions
95
+ IMPORT_STAR = unimplemented_instruction
96
+
97
+ YIELD_FROM = SETUP_ANNOTATIONS = LOAD_BUILD_CLASS = MATCH_MAPPING = MATCH_SEQUENCE = MATCH_KEYS = MATCH_CLASS = unimplemented_instruction
98
+
99
+ # don't find any interesting use case for these instructions
100
+ CALL_INTRINSIC_2 = unimplemented_instruction
101
+
102
+
103
+ # ==================== NOP Instructions =============================
104
+
105
+ def generic_nop(self, inst: Instruction):
106
+ pass
107
+
108
+ # "EXTENDED_ARG" is treated as NOP here, because it has been handled by `dis.get_instructions`.
109
+ # The extended args are already merged into the following instruction's
110
+ # `inst.argval`.
111
+ EXTENDED_ARG = generic_nop
112
+
113
+ NOP = RESUME = SETUP_LOOP = POP_BLOCK = PRECALL = BEGIN_FINALLY = END_FINALLY = generic_nop
114
+
115
+ MAKE_CELL = generic_nop
116
+
117
+ RERAISE = generic_nop
118
+
119
+ # our FOR_ITER is different from CPython's FOR_ITER (as it does not need
120
+ # to explicitly consider the case of exhausted iterator), so we don't need
121
+ # to do anything here
122
+ END_FOR = generic_nop
123
+
124
+
125
+ # ==================== Load Instructions =============================
126
+
127
def LOAD_CONST(self, inst: Instruction):
    """Push a constant onto the stack.
    `inst.argval` is the constant value, we have to use `repr` to get the source code.
    """
    # Emit the constant literally only when `repr` round-trips, i.e.
    # `eval(repr(value)) == value`; otherwise fall through to special cases.
    can_repr = False
    try:
        can_repr = eval(repr(inst.argval)) == inst.argval
    except BaseException:
        pass
    if can_repr:
        self.state.stack.append(repr(inst.argval))
    else:
        if isinstance(inst.argval, type):
            # Don't know why a class type get here, support this corner
            # case anyway: re-import the class by module + name.
            module = inst.argval.__module__
            name = inst.argval.__name__
            self.state.source_code += "import importlib\n"
            temp_name = self.get_temp_name()
            self.state.source_code += f'{temp_name} = importlib.import_module("{module}").{name}\n'
            self.state.stack.append(temp_name)
        elif inst.argrepr.startswith("torch."):
            # Don't know why torch.xxx get here, support this corner case
            # anyway. This deals with something like `torch.float`.
            self.state.source_code += "import torch\n"
            temp_name = self.get_temp_name()
            self.state.source_code += f'{temp_name} = {inst.argval}\n'
            self.state.stack.append(temp_name)
        elif isinstance(inst.argval, CodeType):
            # used in MAKE_FUNCTION
            self.state.stack.append(inst.argval)
        else:
            # Last resort: reference the constant by its index in co_consts.
            # NOTE(review): this string placeholder is presumably resolved
            # by a later pass — confirm where `__co_consts` is substituted.
            self.state.stack.append(f"'__co_consts[{inst.arg}]'")
161
def generic_load(self, inst: Instruction):
    """Push the variable named by `inst.argval` (a string) onto the stack."""
    if "NULL + " in inst.argrepr:
        # Python 3.11 support: the load also pushes a NULL below the value.
        self.state.stack.append(None)
    name = inst.argval
    if inst.argrepr.startswith("."):
        # Implicit argument of a list/set/tuple comprehension (e.g. ".0");
        # rewrite it into a legal identifier.
        name = name.replace(".", "comp_arg_")
    self.state.stack.append(name)

LOAD_FAST = LOAD_FAST_AND_CLEAR = LOAD_FAST_CHECK = LOAD_GLOBAL = LOAD_DEREF = LOAD_NAME = LOAD_CLASSDEREF = LOAD_CLOSURE = generic_load
174
def LOAD_LOCALS(self, inst: Instruction):
    """3.12+: push the `locals()` mapping; it is mutable, so materialize it
    into a temp variable immediately."""
    self.state.stack.append("locals()")
    self.replace_mutable_tos_with_temp()
178
def LOAD_FROM_DICT_OR_GLOBALS(self, inst: Instruction):
    """3.12+: look up `inst.argval` in the mapping at TOS, falling back to
    the name from the enclosing scope when the key is absent.

    Fix: the dictionary subscript must use the quoted *string* key
    (`tos['name']`); the previous code emitted the bare name (`tos[name]`),
    which indexes with a variable and is inconsistent with the quoted
    membership test.
    """
    tos = self.state.stack.pop()
    name = inst.argval
    self.state.stack.append(
        f"{tos}['{name}'] if '{name}' in {tos} else {name}")
    # The lookup result may be mutable; pin it to a temp name.
    self.replace_mutable_tos_with_temp()

LOAD_FROM_DICT_OR_DEREF = LOAD_FROM_DICT_OR_GLOBALS
186
def MAKE_FUNCTION(self, inst: Instruction):
    """Create a function from the code object on the stack and bind it to a
    usable name in the generated source.

    Pre-3.11 the qualified name sits on the stack above the code object;
    3.11+ dropped it, so a temp name is used unless the function is
    immediately stored via STORE_FAST (then that variable name is reused
    and the STORE_FAST is skipped).
    """
    if sys.version_info < (3, 11):
        qual_name = self.state.stack.pop()
        try:
            qual_name = eval(qual_name)
        except Exception:
            pass
        # qual_name for inner function is something like `LongformerEncoder.forward.<locals>.create_custom_forward`
        # get the last part of the name, which is the function name
        func_name = qual_name.split(".")[-1]
        if "<" in func_name:
            # e.g. `<lambda>` / `<listcomp>` — not a legal identifier
            self.state.source_code += f'"original function name {func_name} is illegal, use a temp name."\n'
            func_name = self.get_temp_name()
    else:
        # Python 3.11 support, see
        # https://docs.python.org/3.11/library/dis.html#opcode-MAKE_FUNCTION
        func_name = self.get_temp_name()
    code = self.state.stack.pop()
    # The low bits of `inst.argval` flag extra items on the stack.
    if inst.argval & 0x08:
        # has closure
        self.state.stack.pop()
    if inst.argval & 0x04:
        # has annotations
        self.state.stack.pop()
    kw_defaults = self.state.stack.pop() if inst.argval & 0x02 else {}
    defaults = self.state.stack.pop() if inst.argval & 0x01 else ()
    if len(kw_defaults) or len(defaults):
        print(
            "Function with default arguments is not supported, ignore the default arguments")
    this_index = self.index_of(inst.offset)
    immediately_used = False
    if self.instructions[this_index + 1].opname == "STORE_FAST":
        # the function is immediately stored in a variable, use that
        # variable name
        func_name = self.instructions[this_index + 1].argval
        immediately_used = True
    # Recursively decompile the inner code object under the chosen name.
    inner_func = Decompiler(code).decompile(overwite_fn_name=func_name)
    self.state.source_code += inner_func
    if not immediately_used:
        self.state.stack.append(func_name)
    else:
        # skip one instruction (the STORE_FAST we already consumed)
        return this_index + 2

def COPY_FREE_VARS(self, inst: Instruction):
    # this opcode is used to copy free variables from the outer scope to the closure
    # it affects the frame, but not the stack or the source code
    pass
235
def LOAD_ATTR(self, inst: Instruction):
    """Replace TOS with an attribute access; fall back to `getattr` for
    attribute names that are not valid identifiers."""
    lhs = str(self.state.stack.pop())
    rhs = inst.argval
    if rhs.isidentifier():
        self.state.stack.append(f"{lhs}.{rhs}")
    else:
        self.state.stack.append(f"getattr({lhs}, {repr(rhs)})")

def LOAD_SUPER_ATTR(self, inst: Instruction):
    # not tested
    # 3.12+: stack holds (super, class, self) from bottom to top.
    self_obj = self.state.stack.pop()
    cls_obj = self.state.stack.pop()
    super_obj = self.state.stack.pop()
    self.state.stack.append(
        f"{super_obj}({cls_obj}, {self_obj}).{inst.argval}")
    self.replace_mutable_tos_with_temp()

def LOAD_METHOD(self, inst: Instruction):
    """Replace TOS with the bound-method expression `tos.name`."""
    self.state.stack.append(f"{self.state.stack.pop()}.{inst.argval}")

def LOAD_ASSERTION_ERROR(self, inst: Instruction):
    """Push the `AssertionError` exception type."""
    self.state.stack.append("AssertionError")

def PUSH_NULL(self, inst: Instruction):
    # the `None` object is used to represent `NULL` in python bytecode
    self.state.stack.append(None)

def GET_ITER(self, inst: Instruction):
    """Replace TOS with `iter(TOS)`."""
    tos = self.state.stack.pop()
    self.state.stack.append(f"iter({tos})")
266
+ # ==================== Store Instructions =============================
267
+
268
def generic_store(self, inst: Instruction):
    """Pop a value and emit `name = value` for name-based store opcodes."""
    target = inst.argval
    value = self.state.stack.pop()
    # Inplace operations like `+=` already emitted their statement and
    # pushed the target name back onto the stack; emitting here would
    # produce a redundant `x = x`, so skip that case.
    if target != value:
        self.state.source_code += f"{target} = {value}\n"

STORE_FAST = STORE_GLOBAL = STORE_DEREF = STORE_NAME = generic_store
278
def STORE_SUBSCR(self, inst: Instruction):
    """Emit `x[index] = value` (stack, top to bottom: index, x, value)."""
    index = self.state.stack.pop()
    x = self.state.stack.pop()
    value = self.state.stack.pop()
    self.state.source_code += f"{x}[{index}] = {value}\n"

def STORE_SLICE(self, inst: Instruction):
    # not tested, code according to
    # https://docs.python.org/3.12/library/dis.html#opcode-STORE_SLICE
    end = self.state.stack.pop()
    start = self.state.stack.pop()
    container = self.state.stack.pop()
    value = self.state.stack.pop()
    self.state.source_code += f"{container}[{start}:{end}] = {value}\n"

def STORE_ATTR(self, inst: Instruction):
    """Emit `x.attr = value` (stack, top to bottom: x, value)."""
    x = self.state.stack.pop()
    value = self.state.stack.pop()
    self.state.source_code += f"{x}.{inst.argval} = {value}\n"
298
# ==================== Del Instructions =============================

def DELETE_SUBSCR(self, inst: Instruction):
    """Emit `del x[index]`."""
    index = self.state.stack.pop()
    x = self.state.stack.pop()
    self.state.source_code += f"del {x}[{index}]\n"

def generic_delete(self, inst: Instruction):
    """Emit `del name` for name-based delete opcodes."""
    self.state.source_code += f"del {inst.argval}\n"

DELETE_NAME = DELETE_GLOBAL = DELETE_DEREF = generic_delete
# `DELETE_FAST` just reduces the ref count by one
# it does not occur as code `del x` in the source code
DELETE_FAST = generic_nop

def DELETE_ATTR(self, inst: Instruction):
    """Emit `del x.attr`."""
    x = self.state.stack.pop()
    self.state.source_code += f"del {x}.{inst.argval}\n"
317
# ==================== Import Instructions =============================
def IMPORT_NAME(self, inst: Instruction):
    # TODO: check multi-level import, e.g. `import a.b.c`
    # Stack (top to bottom): fromlist, level. Emit an explicit
    # `__import__` call and bind the top-level package name.
    name = inst.argval.split(".")[0]
    fromlist = self.state.stack.pop()
    level = self.state.stack.pop()
    self.state.source_code += f"{name} = __import__({repr(inst.argval)}, fromlist={fromlist}, level={level})\n"
    self.state.stack.append(name)

def IMPORT_FROM(self, inst: Instruction):
    """Bind one attribute from the module object that remains on the stack
    (the module is peeked, not popped, so several names can be imported)."""
    name = inst.argval
    module = self.state.stack[-1]
    self.state.source_code += f"{name} = {module}.{name}\n"
    self.state.stack.append(name)
332
# ==================== Unary Instructions =============================

def generic_unary(self, inst: Instruction):
    """Pop one operand and push the parenthesized unary expression."""
    op = {
        "UNARY_NEGATIVE": "-",
        "UNARY_POSITIVE": "+",
        "UNARY_INVERT": "~",
        "UNARY_NOT": "not",
    }[inst.opname]
    self.state.stack.append(f"({op} {self.state.stack.pop()})")

UNARY_NEGATIVE = UNARY_POSITIVE = UNARY_INVERT = UNARY_NOT = generic_unary

def GET_LEN(self, inst: Instruction):
    """Push `len(TOS)`; TOS itself is kept on the stack."""
    self.state.stack.append(f"len({self.state.stack[-1]})")
348
# ==================== Binary Instructions =============================
def generic_binary(self, inst: Instruction):
    """Pop two operands (rhs on top) and push the parenthesized binary
    expression for the opcode."""
    rhs = self.state.stack.pop()
    lhs = self.state.stack.pop()
    op = {
        "BINARY_MULTIPLY": "*",
        "BINARY_ADD": "+",
        "BINARY_SUBTRACT": "-",
        "BINARY_TRUE_DIVIDE": "/",
        "BINARY_FLOOR_DIVIDE": "//",
        "BINARY_MODULO": "%",
        "BINARY_POWER": "**",
        "BINARY_AND": "&",
        "BINARY_OR": "|",
        "BINARY_XOR": "^",
        "BINARY_LSHIFT": "<<",
        "BINARY_RSHIFT": ">>",
        "BINARY_MATRIX_MULTIPLY": "@",
    }[inst.opname]
    self.state.stack.append(f"({lhs} {op} {rhs})")

BINARY_MULTIPLY = BINARY_ADD = BINARY_SUBTRACT = BINARY_TRUE_DIVIDE = BINARY_FLOOR_DIVIDE = BINARY_MODULO = BINARY_POWER = BINARY_AND = BINARY_OR = BINARY_XOR = BINARY_LSHIFT = BINARY_RSHIFT = BINARY_MATRIX_MULTIPLY = generic_binary

def BINARY_SUBSCR(self, inst: Instruction):
    """Pop index and container, push `container[index]`."""
    rhs = self.state.stack.pop()
    lhs = self.state.stack.pop()
    self.state.stack.append(f"{lhs}[{rhs}]")

def BINARY_SLICE(self, inst: Instruction):
    """3.12+: pop end, start, container and push `container[start:end]`."""
    end = self.state.stack.pop()
    start = self.state.stack.pop()
    container = self.state.stack.pop()
    self.state.stack.append(f"{container}[{start}:{end}]")
382
# ==================== Binary Inplace Instructions =======================
def generic_inplace_binary(self, inst: Instruction):
    """Emit an augmented assignment (`lhs op= rhs`) as a statement, then
    push the target name back; the paired STORE then sees target == value
    and skips emitting a redundant `lhs = lhs`."""
    rhs = self.state.stack.pop()
    lhs = self.state.stack.pop()
    op = {
        "INPLACE_MULTIPLY": "*",
        "INPLACE_ADD": "+",
        "INPLACE_SUBTRACT": "-",
        "INPLACE_TRUE_DIVIDE": "/",
        "INPLACE_FLOOR_DIVIDE": "//",
        "INPLACE_MODULO": "%",
        "INPLACE_POWER": "**",
        "INPLACE_AND": "&",
        "INPLACE_OR": "|",
        "INPLACE_XOR": "^",
        "INPLACE_LSHIFT": "<<",
        "INPLACE_RSHIFT": ">>",
        "INPLACE_MATRIX_MULTIPLY": "@",
    }[inst.opname]
    self.state.source_code += f"{lhs} {op}= {rhs}\n"
    self.state.stack.append(lhs)

INPLACE_MULTIPLY = INPLACE_ADD = INPLACE_SUBTRACT = INPLACE_TRUE_DIVIDE = INPLACE_FLOOR_DIVIDE = INPLACE_MODULO = INPLACE_POWER = INPLACE_AND = INPLACE_OR = INPLACE_XOR = INPLACE_LSHIFT = INPLACE_RSHIFT = INPLACE_MATRIX_MULTIPLY = generic_inplace_binary

def BINARY_OP(self, inst: Instruction):
    """3.11+ unified binary opcode; `inst.argrepr` holds the operator text.
    An in-place variant (argrepr like `+=`) is emitted as a statement,
    mirroring `generic_inplace_binary`; otherwise push the expression."""
    rhs = self.state.stack.pop()
    lhs = self.state.stack.pop()
    if "=" in inst.argrepr:
        self.state.source_code += f"{lhs} {inst.argrepr} {rhs}\n"
        self.state.stack.append(lhs)
    else:
        self.state.stack.append(f"({lhs} {inst.argrepr} {rhs})")
415
+ # ==================== Conditional Test Instructions =====================
416
def COMPARE_OP(self, inst: Instruction):
    """Pop two operands and push the comparison expression, e.g. `(a < b)`;
    `inst.argval` is the operator text."""
    rhs, lhs = self.state.stack.pop(), self.state.stack.pop()
    self.state.stack.append(f"({lhs} {inst.argval} {rhs})")

def IS_OP(self, inst: Instruction):
    """Identity test; `inst.argval` selects `is` (0) vs `is not` (1)."""
    rhs, lhs = self.state.stack.pop(), self.state.stack.pop()
    keyword = "is not" if inst.argval else "is"
    self.state.stack.append(f"({lhs} {keyword} {rhs})")

def CONTAINS_OP(self, inst: Instruction):
    """Membership test; `inst.argval` selects `in` (0) vs `not in` (1)."""
    rhs, lhs = self.state.stack.pop(), self.state.stack.pop()
    keyword = "not in" if inst.argval else "in"
    self.state.stack.append(f"({lhs} {keyword} {rhs})")
433
+ # ==================== Control Flow Instructions =============================
434
+
435
def BREAK_LOOP(self, inst: Instruction):
    """Emit a `break` statement (legacy loop opcode)."""
    self.state.source_code += "break\n"

def generic_abs_jump(self, inst: Instruction):
    """Handle unconditional jumps.

    Inside a loop, a jump at or past the loop end becomes `break` and a
    jump back to (or before) the loop start becomes `continue`; otherwise
    return the target index so decompilation continues from there.
    """
    jump_offset = inst.get_jump_target()
    jump_index = self.index_of(jump_offset)
    if self.state.inside_loop:
        if jump_index >= self.state.loop_end_index:
            self.state.source_code += "break\n"
        elif jump_index <= self.state.loop_start_index:
            self.state.source_code += "continue\n"
        else:
            return jump_index
    else:
        return jump_index

JUMP_ABSOLUTE = JUMP_FORWARD = JUMP_BACKWARD = JUMP_BACKWARD_NO_INTERRUPT = generic_abs_jump

def RETURN_VALUE(self, inst: Instruction):
    """Emit `return TOS` and pop the returned expression."""
    self.state.source_code += f"return {self.state.stack[-1]}\n"
    self.state.stack.pop()
457
def RETURN_CONST(self, inst: Instruction):
    """3.12+: return a constant directly (no stack traffic).

    Fix: use `repr` so the constant round-trips as valid source.
    `inst.argval` is the constant *value*; interpolating it unquoted would
    emit `return abc` for the string constant "abc". `repr(None)` is
    "None", so the common `return None` case is unchanged.
    """
    self.state.source_code += f"return {inst.argval!r}\n"
460
def YIELD_VALUE(self, inst: Instruction):
    """Emit `yield TOS`.

    NOTE(review): TOS is intentionally not popped here — presumably it
    stands in for the value sent into the generator; confirm against the
    handling of the instruction that follows a yield.
    """
    if sys.version_info >= (3, 12):
        raise NotImplementedError(
            "YIELD_VALUE is not supported in Python 3.12")
    self.state.source_code += f"yield {self.state.stack[-1]}\n"

def RETURN_GENERATOR(self, inst: Instruction):
    # we don't handle generator/coroutine, add this to support simple yield
    self.state.stack.append(None)

def GEN_START(self, inst: Instruction):
    # self.state.stack.pop()
    assert inst.argval == 0, "Only generator expression is supported"
474
def generic_jump_if(self, inst: Instruction):
    """How we support if-else:

    Failed idea: try to partition the block of instructions into if and else.
    This is not possible, as the if-else block might have overlapping instructions.
    Take this function as an example:

    def f(a):
        b = 1 if a else 2
        print(b)

    The bytecode is:
    2 0 LOAD_FAST 0 (a)
    2 POP_JUMP_IF_FALSE 4 (to 8)
    4 LOAD_CONST 1 (1)
    6 JUMP_FORWARD 1 (to 10)
    >> 8 LOAD_CONST 2 (2)
    >> 10 STORE_FAST 1 (b)

    3 12 LOAD_GLOBAL 0 (print)
    14 LOAD_FAST 1 (b)
    16 CALL_FUNCTION 1
    18 POP_TOP
    20 LOAD_CONST 0 (None)
    22 RETURN_VALUE

    The instructions for if branch: 2, 4, 6, 10
    The instructions for else branch: 8, 10
    They share the same instruction 10, so we cannot partition the block into if and else.

    Another example:

    def f():
        g(arg1=a if a is not None else b, arg2=2)
        print(1)

    The bytecode is:

    2 0 LOAD_GLOBAL 0 (g)
    2 LOAD_GLOBAL 1 (a)
    4 LOAD_CONST 0 (None)
    6 IS_OP 1
    8 POP_JUMP_IF_FALSE 7 (to 14)
    10 LOAD_GLOBAL 1 (a)
    12 JUMP_FORWARD 1 (to 16)
    >> 14 LOAD_GLOBAL 2 (b)
    >> 16 LOAD_CONST 1 (2)
    18 LOAD_CONST 2 (('arg1', 'arg2'))
    20 CALL_FUNCTION_KW 2
    22 POP_TOP

    3 24 LOAD_GLOBAL 3 (print)
    26 LOAD_CONST 3 (1)
    28 CALL_FUNCTION 1
    30 POP_TOP
    32 LOAD_CONST 0 (None)
    34 RETURN_VALUE

    The instructions for if branch: 8, 14, 16, 18, 20, 22
    The instructions for else branch: 10, 12, 16, 18, 20, 22
    They share the same instructions 16, 18, 20, 22, so we cannot partition the block into if and else.

    Current idea:

    We take advantage of the following fact:

    This code snippet:

    if cond:
        if-body
    else:
        else-body
    rest-body

    is equivalent to:

    if cond:
        if-body
        rest-body
    else:
        else-body
        rest-body

    By duplicating the rest-body, we can decompile the if-else block separately. And they will have some duplicated code.

    Of course, we don't want to duplicate too long code, so we need to find the end of if-else block.
    The current heuristic is to find the first store/return/jump/for-iter instruction after the if-else block (because they are indicators that we will generate meaningful source code).
    """
    jump_offset = inst.get_jump_target()
    jump_index = self.index_of(jump_offset)
    this_index = self.index_of(inst.offset)
    cond = self.state.stack[-1]
    fallthrough_stack = self.state.stack.copy()
    jump_stack = self.state.stack.copy()

    # The fallthrough branch runs when the jump is NOT taken, so the
    # emitted `if` condition is the negation of the jump condition.
    if "IF_NOT_NONE" in inst.opname:
        cond = f"({cond} is None)"
    elif "IF_NONE" in inst.opname:
        cond = f"({cond} is not None)"
    elif "IF_TRUE" in inst.opname:
        cond = f"(not {cond})"
    elif "IF_FALSE" in inst.opname:
        cond = f"{cond}"

    # POP_AND_JUMP / JUMP_OR_POP
    if "POP_JUMP" in inst.opname:
        jump_stack.pop()
        fallthrough_stack.pop()
    elif "OR_POP" in inst.opname:
        # TOS is popped only on the fallthrough path
        fallthrough_stack.pop()

    end_index_candidates = [len(self.instructions)]
    if self.state.inside_loop:
        end_index_candidates.append(self.state.loop_end_index)

    def qualified_jump(i: Instruction):
        # forward jumps that land at or beyond this opcode's target
        return i.is_jump() and i.get_jump_target() >= jump_offset

    jump_targets = [i.get_jump_target() for i in self.instructions[this_index: jump_index] if qualified_jump(i)]

    if not jump_targets:
        # this is a jump back, we will generate a ``continue`` statement
        # normally `if` condition is for the fallthrough code, but in this case
        # we need to generate the `if` condition for the jump code
        # therefore the condition is reversed
        cond = self.state.stack[-1]
        if "IF_NOT_NONE" in inst.opname:
            cond = f"({cond} is not None)"
        elif "IF_NONE" in inst.opname:
            cond = f"({cond} is None)"
        elif "IF_TRUE" in inst.opname:
            cond = f"{cond}"
        elif "IF_FALSE" in inst.opname:
            cond = f"(not {cond})"
        if_code = f"if {cond}:\n" + add_indentation("continue\n", self.indentation)
        self.state.source_code += if_code
        return

    max_jump = max(jump_targets)
    max_jump_index = self.index_of(max_jump)
    # else branch might have jumps, we need to find the end of the else
    all_jump_targets = [i.get_jump_target() for i in self.instructions[this_index: max_jump_index] if qualified_jump(i)]
    max_jump_index = self.index_of(max(all_jump_targets))
    last_inst = self.instructions[max_jump_index - 1]
    if "RAISE" in last_inst.opname or "RETURN" in last_inst.opname or "STORE" in last_inst.opname:
        # if-body instructions end with raise/return/store, it is very likely that if-body and else-body don't share any instructions
        pass
    else:
        old_map_jump_index = max_jump_index
        while max_jump_index < len(self.instructions):
            opname = self.instructions[max_jump_index].opname
            if "STORE" in opname or "RETURN" in opname:
                # we want to include the store/return instruction in the if-else block
                max_jump_index += 1
                break
            elif ("JUMP" in opname and max_jump_index > old_map_jump_index) or "FOR_ITER" in opname:
                # we don't want to include the jump instruction in the if-else block
                break
            max_jump_index += 1
        end_index_candidates.append(max_jump_index)

    end_index = min(end_index_candidates)

    # Decompile the fallthrough (if) branch up to the shared end index.
    with self.new_state(fallthrough_stack):
        self.decompile_range(this_index + 1, end_index)
        if_body = self.state.source_code
        if_body = add_indentation(if_body, self.indentation)
        if_end_stack = self.state.stack.copy()
    if_code = f"if {cond}:\n{if_body}"
    self.state.source_code += if_code

    # Decompile the jump-target (else) branch up to the same end index.
    with self.new_state(jump_stack):
        self.decompile_range(jump_index, end_index)
        else_body = self.state.source_code
    if else_body:
        else_body = add_indentation(else_body, self.indentation)
        else_code = f"else:\n{else_body}"
        self.state.source_code += else_code

    # Both branches end with equivalent stacks; adopt the if-branch's.
    self.state.stack = if_end_stack
    return end_index


POP_JUMP_IF_TRUE = POP_JUMP_IF_FALSE = generic_jump_if
POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_FORWARD_IF_FALSE = generic_jump_if
POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_BACKWARD_IF_FALSE = generic_jump_if
POP_JUMP_FORWARD_IF_NONE = POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump_if
POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump_if
JUMP_IF_TRUE_OR_POP = JUMP_IF_FALSE_OR_POP = generic_jump_if
POP_JUMP_IF_NOT_NONE = POP_JUMP_BACKWARD_IF_NOT_NONE
POP_JUMP_IF_NONE = POP_JUMP_BACKWARD_IF_NONE
666
def SETUP_FINALLY(self, inst: Instruction):
    """Decompile a `try ... finally` block.

    The try-body spans up to the last POP_BLOCK before the handler target;
    the finally-body runs from there up to END_FINALLY (when present),
    excluding a trailing jump.
    """
    start_index = self.index_of(inst.offset)
    end_index = self.index_of(inst.get_jump_target())
    pop_block_index = [i for i, x in enumerate(
        self.instructions) if x.opname == "POP_BLOCK" and start_index <= i < end_index][-1]

    try_code = ""
    with self.new_state(self.state.stack):
        self.decompile_range(start_index + 1, pop_block_index)
        try_code = self.state.source_code
    try_code = add_indentation(try_code, self.indentation)
    try_code = "try:\n" + try_code

    finally_code = ""
    with self.new_state(self.state.stack):
        end_finally_index = [
            i for i, x in enumerate(
                self.instructions) if x.opname == "END_FINALLY" and start_index <= i]
        if end_finally_index:
            end_index = end_finally_index[0]
        finally_end_index = end_index
        # a trailing jump belongs to the surrounding control flow, not the
        # finally body
        if self.instructions[finally_end_index - 1].is_jump():
            finally_end_index -= 1
        self.decompile_range(pop_block_index + 1, finally_end_index)
        finally_code = self.state.source_code
    finally_code = add_indentation(finally_code, self.indentation)
    finally_code = "finally:\n" + finally_code

    self.state.source_code += try_code + finally_code
    return end_index
697
def SETUP_WITH(self, inst: Instruction):
    """
    with expression as var:
        body

    is equivalent to:

    var = expression
    var.__enter__()
    try:
        body
    finally:
        var.__exit__()

    We find the start of `finally` by `WITH_EXCEPT_START`, and the end of `finally` by `POP_EXCEPT`.
    In early python version, the start is `WITH_CLEANUP_START` and the end is `WITH_CLEANUP_FINISH`.
    """
    start_index = self.index_of(inst.offset)
    with_except_index = [i for i, x in enumerate(
        self.instructions) if x.opname in ["WITH_EXCEPT_START", "WITH_CLEANUP_START"] and i > start_index][-1]
    end_index = with_except_index
    # the exception-handling machinery is absorbed into the `with` clause
    nop_instruction(self.instructions[end_index])

    # NOP PUSH_EXC_INFO and JUMP_FORWARD
    i = end_index - 1
    while end_index - i <= 2:
        _inst = self.instructions[i]
        if _inst.opname.startswith("JUMP") or _inst.opname == "PUSH_EXC_INFO":
            nop_instruction(_inst)
        i -= 1

    pop_except_indices = [i for i, x in enumerate(
        self.instructions) if x.opname in ["POP_EXCEPT", "WITH_CLEANUP_FINISH"] and i > end_index]
    if sys.version_info >= (3, 11):
        # Python 3.11 seems to have two `POP_EXCEPT` instructions, not sure why.
        pop_except_index = pop_except_indices[1]
    else:
        pop_except_index = pop_except_indices[0]
    for i in range(end_index, pop_except_index + 1):
        nop_instruction(self.instructions[i])
    tos = self.state.stack[-1]
    temp = self.get_temp_name()
    # mimic SETUP_WITH's stack effect: push __exit__ then __enter__ result
    self.state.stack.append(f"{temp}.__exit__")
    self.state.stack.append(temp)
    with_clause = f"with {tos} as {temp}:\n"
    with_body = ""
    with self.new_state(self.state.stack):
        self.decompile_range(start_index + 1, end_index)
        with_body = self.state.source_code
    with_body = add_indentation(with_body, self.indentation)
    lines = with_body.splitlines()
    ans = []
    for line in lines:
        if f"{temp}.__exit__" in line or "None(None, None)" in line.strip():
            # this is the line that calls __exit__, we need to remove it, as it is managed by `with` statement.
            # `None(None, None)` is used for Python 3.11. Who knows why it loads three Nones but call with 2 args for the following simple code:
            # def f():
            #     with a:
            #         print(2)
            continue
        ans.append(line)
    with_body = "".join([x + "\n" for x in ans])

    self.state.source_code += with_clause + with_body
    return pop_except_index + 1

BEFORE_WITH = SETUP_WITH
765
def FOR_ITER(self, inst: Instruction):
    """Decompile a `for` loop: bind each item to a temp name and decompile
    the body with the loop bounds recorded (so nested jumps can be mapped
    to break/continue by `generic_abs_jump`)."""
    start_index = self.index_of(inst.offset)
    end_index = self.index_of(inst.get_jump_target())

    temp_name = self.get_temp_name()
    for_code = f"for {temp_name} in {self.state.stack.pop()}:\n"
    self.state.stack.append(temp_name)
    last_inst = self.instructions[end_index]
    if last_inst.is_jump() and last_inst.get_jump_target() == inst.offset:
        # if end_index is something like jumping back to for_iter,
        # we should deal with it inside the loop
        end_index += 1
    with self.new_state(self.state.stack, inside_loop=True, loop_start_index=start_index, loop_end_index=end_index):
        self.decompile_range(start_index + 1, end_index)
        code = self.state.source_code
        for_code = for_code + add_indentation(code, self.indentation)
        for_end_stack = self.state.stack.copy()
    self.state.source_code += for_code
    self.state.stack = for_end_stack
    return end_index
786
# ==================== Stack Manipulation Instructions ===================
def rot_n(self, inst: Instruction):
    """Rotate the top `n` stack items: TOS moves down to position `n` and
    the items above it shift up by one."""
    if inst.opname == "ROT_N":
        n = inst.argval
    else:
        n = {
            "ROT_TWO": 2,
            "ROT_THREE": 3,
            "ROT_FOUR": 4,
        }[inst.opname]
    values = self.state.stack[-n:]
    values = [values[-1]] + values[:-1]
    self.state.stack[-n:] = values

ROT_N = ROT_TWO = ROT_THREE = ROT_FOUR = rot_n
802
def SWAP(self, inst: Instruction):
    """Exchange TOS with the item `inst.argval` entries down the stack
    (3.11+ `SWAP i`)."""
    depth = inst.argval
    stack = self.state.stack
    stack[-1], stack[-depth] = stack[-depth], stack[-1]
810
def COPY(self, inst: Instruction):
    """3.11+ `COPY i`: push a copy of the item `i` entries down the stack.

    Fix: CPython defines this as `STACK.append(STACK[-i])`; the previous
    index `STACK[-1 - n]` was off by one (the opcode was marked untested).
    """
    n = inst.argval
    self.state.stack.append(self.state.stack[-n])
816
def POP_TOP(self, inst: Instruction):
    """Discard TOS."""
    self.state.stack.pop()

def DUP_TOP(self, inst: Instruction):
    # not tested
    # duplicate TOS
    self.state.stack.append(self.state.stack[-1])

def DUP_TOP_TWO(self, inst: Instruction):
    # not tested
    # duplicate the top two items, preserving their order
    tos = self.state.stack[-1]
    tos1 = self.state.stack[-2]
    self.state.stack.append(tos1)
    self.state.stack.append(tos)
830
# ==================== Function Call Instructions =============================
def KW_NAMES(self, inst: Instruction):
    """3.11/3.12: push the tuple of keyword-argument names (as its repr)
    for the following CALL."""
    names = self.code.co_consts[inst.arg]
    self.state.stack.append(repr(names))

def CALL(self, inst: Instruction):
    """3.11+ unified call opcode.

    Keyword names, if any, were pushed by a preceding KW_NAMES; detect it
    by peeking at the previous instruction(s) (a PRECALL may sit between).
    """
    last_inst = [x for x in self.instructions if x.offset < inst.offset]
    has_kw_names = False
    if last_inst:
        if last_inst[-1].opname == "KW_NAMES" or (len(
                last_inst) > 1 and last_inst[-2].opname == "KW_NAMES" and last_inst[-1].opname == "PRECALL"):
            has_kw_names = True
    kw_names = tuple()
    if has_kw_names:
        kw_names = eval(self.state.stack.pop())
    # `inst.argval` counts positional + keyword arguments together; the
    # trailing len(kw_names) of them are the keyword values.
    args = [(self.state.stack.pop()) for _ in range(inst.argval)]
    args = args[::-1]
    pos_args = args[:len(args) - len(kw_names)]
    kwargs = args[len(args) - len(kw_names):]
    kwcalls = []
    for name, value in zip(kw_names, kwargs):
        kwcalls.append(f"{name}={value}")
    func = self.state.stack.pop()
    # pop the NULL (represented as None) pushed by the 3.11 call protocol
    if self.state.stack and self.state.stack[-1] is None:
        self.state.stack.pop()
    if "iter(" in func:
        # Why do we need this? Don't know. But sometimes CPython generates
        # CALL with argval=0, but the function actually needs an arg (for
        # list/set/map comprehension).
        pos_args = [func]
        func = self.state.stack.pop()
    self.state.stack.append(f"{func}({', '.join(pos_args + kwcalls)})")
    self.replace_mutable_tos_with_temp()
864
def generic_call(self, inst: Instruction):
    """Pop `inst.argval` positional arguments plus the callee and push the
    call expression; the (possibly mutable) result is materialized into a
    temp name."""
    argc = inst.argval
    arguments = [self.state.stack.pop() for _ in range(argc)]
    arguments.reverse()
    callee = self.state.stack.pop()
    self.state.stack.append(f"{callee}({', '.join(arguments)})")
    self.replace_mutable_tos_with_temp()

CALL_FUNCTION = CALL_METHOD = generic_call
873
def CALL_FUNCTION_KW(self, inst: Instruction):
    """Pre-3.11 call with keyword args: TOS is a tuple (repr) of keyword
    names, below it the keyword values, then the positional args, then
    the callee."""
    kw_args = eval(self.state.stack.pop())
    kw_vals = [(self.state.stack.pop()) for _ in range(len(kw_args))]
    kw_vals.reverse()
    kwcalls = []
    for name, val in zip(kw_args, kw_vals):
        kwcalls.append(f"{name}={val}")
    pos_args = [(self.state.stack.pop())
                for _ in range(inst.argval - len(kw_args))]
    pos_args = pos_args[::-1]
    func = self.state.stack.pop()
    self.state.stack.append(f"{func}({', '.join(pos_args + kwcalls)})")
    self.replace_mutable_tos_with_temp()

def CALL_FUNCTION_EX(self, inst: Instruction):
    """Call with iterable/mapping unpacking: `f(*args)` (flag 0) or
    `f(*args, **kwargs)` (flag 1).

    NOTE(review): the diff formatting obscures whether the trailing
    `replace_mutable_tos_with_temp` call applies to both branches — it is
    placed at function level here; confirm against upstream.
    """
    if inst.argval == 0:
        args = self.state.stack.pop()
        func = self.state.stack.pop()
        self.state.stack.append(f"{func}(*{args})")
    elif inst.argval == 1:
        kw_args = self.state.stack.pop()
        args = self.state.stack.pop()
        func = self.state.stack.pop()
        self.state.stack.append(f"{func}(*{args}, **{kw_args})")
    self.replace_mutable_tos_with_temp()
899
def CALL_INTRINSIC_1(self, inst: Instruction):
    """3.12+ one-argument intrinsic dispatch; `inst.argrepr` names the
    intrinsic."""
    if inst.argrepr in [
        "INTRINSIC_1_INVALID",
        "INTRINSIC_IMPORT_STAR",
        "INTRINSIC_STOPITERATION_ERROR",
        "INTRINSIC_ASYNC_GEN_WRAP"]:
        # invalid intrinsic, skip
        pass
    elif inst.argrepr in ["INTRINSIC_TYPEVAR", "INTRINSIC_PARAMSPEC", "INTRINSIC_TYPEVARTUPLE", "INTRINSIC_SUBSCRIPT_GENERIC", "INTRINSIC_TYPEALIAS"]:
        # not tested, skip
        pass
    elif inst.argrepr == "INTRINSIC_PRINT":
        # REPL-style expression printing; the call evaluates to None
        self.state.source_code += f"print({self.state.stack.pop()})\n"
        self.state.stack.append("None")
    elif inst.argrepr == "INTRINSIC_UNARY_POSITIVE":
        self.state.stack[-1] = f"+{self.state.stack[-1]}"
    elif inst.argrepr == "INTRINSIC_LIST_TO_TUPLE":
        # delegate to the dedicated handler
        return self.LIST_TO_TUPLE(inst)
919
+ # ==================== Container Related Instructions (tuple, list, set, d
920
+
921
def UNPACK_SEQUENCE(self, inst: Instruction):
    # sequence can be tuple, list, or even generator
    # we cannot directly use indexing to get the elements
    # because the sequence might be a generator (not subscriptable)
    # instead, we use a temporary variable to store the unpacked elements

    # e.g. `a, b = (None for _ in (1, 2))`
    # will be transformed into:
    # __temp_1 = (None for _ in (1, 2))
    # __temp_2, __temp_3 = __temp_1
    # a = __temp_2
    # b = __temp_3
    varname = self.state.stack.pop()
    tmp_names = []
    for i in range(inst.argval):
        tmp_names.append(self.get_temp_name())
    # NOTE: even if there is only one element, we still need to unpack it
    # a = b is different from a, = b
    lhs = "".join([f"{x}, " for x in tmp_names])
    self.state.source_code += lhs + f"= {varname}\n"
    # push in reverse so the first unpacked element is on top
    for name in tmp_names[::-1]:
        self.state.stack.append(name)

def UNPACK_EX(self, inst: Instruction):
    """Starred unpacking (`a, b, *rest = seq`): `inst.argval` names before
    the star; the starred remainder is bound to one extra temp."""
    varname = self.state.stack.pop()
    tmp_names = []
    for i in range(inst.argval):
        tmp_names.append(self.get_temp_name())
    star_name = self.get_temp_name()
    self.state.source_code += ", ".join(tmp_names) + f", *{star_name}" + f" = {varname}\n"
    # the starred list goes deepest; leading names on top, first on TOS
    self.state.stack.append(star_name)
    for name in tmp_names[::-1]:
        self.state.stack.append(name)
955
def BUILD_SLICE(self, inst: Instruction):
    """Pop 2 or 3 slice components (per `inst.argval`) and push the
    corresponding `slice(...)` expression."""
    top = self.state.stack.pop()
    below = self.state.stack.pop()
    if inst.argval == 2:
        # (start, stop)
        self.state.stack.append(f"slice({below}, {top})")
    elif inst.argval == 3:
        # (start, stop, step)
        start = self.state.stack.pop()
        self.state.stack.append(f"slice({start}, {below}, {top})")
964
def build_tuple(self, inst: Instruction):
    """Pop `inst.argval` items and push a tuple display; UNPACK variants
    splat each item with `*`."""
    args = [self.state.stack.pop() for _ in range(inst.argval)]
    args = args[::-1]
    if "UNPACK" in inst.opname:
        args = [f"*{arg}" for arg in args]
    if inst.argval == 1:
        # single-element tuple needs the trailing comma
        self.state.stack.append(f"({args[0]},)")
    else:
        self.state.stack.append(f"({', '.join(args)})")

BUILD_TUPLE = BUILD_TUPLE_UNPACK = BUILD_TUPLE_UNPACK_WITH_CALL = build_tuple

def build_list(self, inst: Instruction):
    """Pop `inst.argval` items and push a list display; lists are mutable,
    so the result is materialized into a temp name."""
    args = [self.state.stack.pop() for _ in range(inst.argval)]
    args = args[::-1]
    if "UNPACK" in inst.opname:
        args = [f"*{arg}" for arg in args]
    self.state.stack.append(f"[{', '.join(args)}]")
    self.replace_mutable_tos_with_temp()

BUILD_LIST = BUILD_LIST_UNPACK = build_list

def build_set(self, inst: Instruction):
    """Pop `inst.argval` items and push a set display (`set()` for zero
    items, since `{}` would be a dict); materialized into a temp name."""
    ans = ""
    if inst.argval == 0:
        ans = "set()"
    else:
        args = [self.state.stack.pop() for _ in range(inst.argval)]
        args = args[::-1]
        if "UNPACK" in inst.opname:
            args = [f"*{arg}" for arg in args]
        ans = f"{{{', '.join(args)}}}"
    self.state.stack.append(ans)
    self.replace_mutable_tos_with_temp()

BUILD_SET = BUILD_SET_UNPACK = build_set

def build_map_unpack(self, inst: Instruction):
    """Pop `inst.argval` mappings and merge them with `**` into one dict
    display; materialized into a temp name."""
    if inst.argval == 0:
        self.state.stack.append("dict()")
    else:
        args = [self.state.stack.pop() for _ in range(inst.argval)]
        args = args[::-1]
        args = [f"**{arg}" for arg in args]
        self.state.stack.append(f"{{{', '.join(args)}}}")
    self.replace_mutable_tos_with_temp()

BUILD_MAP_UNPACK = BUILD_MAP_UNPACK_WITH_CALL = build_map_unpack

def BUILD_MAP(self, inst: Instruction):
    """Pop `inst.argval` key/value pairs (interleaved on the stack) and
    push a dict display; materialized into a temp name."""
    args = [self.state.stack.pop() for _ in range(inst.argval * 2)]
    args = args[::-1]
    keys = args[::2]
    values = args[1::2]
    self.state.stack.append(
        f"{{{', '.join([f'{k}: {v}' for k, v in zip(keys, values)])}}}")
    self.replace_mutable_tos_with_temp()
    def BUILD_CONST_KEY_MAP(self, inst: Instruction):
        """Decompile ``BUILD_CONST_KEY_MAP``: top of stack holds the repr of a
        constant tuple of keys; below it are ``argval`` values.

        NOTE: ``eval`` here runs on the repr of a constant tuple taken from
        the code object's own constants, not on external input.
        """
        keys = eval(self.state.stack.pop())
        args = [self.state.stack.pop() for _ in range(inst.argval)]
        values = args[::-1]
        self.state.stack.append(
            f"{{{', '.join([f'{k}: {v}' for k, v in zip(keys, values)])}}}")
        self.replace_mutable_tos_with_temp()
1029
+
1030
+ def BUILD_STRING(self, inst: Instruction):
1031
+ args = [self.state.stack.pop() for _ in range(inst.argval)]
1032
+ args = args[::-1]
1033
+ values = " + ".join(args)
1034
+ self.state.stack.append(values)
1035
+
1036
+ def LIST_TO_TUPLE(self, inst: Instruction):
1037
+ item = self.state.stack.pop()
1038
+ self.state.stack.append(f"tuple({item})")
1039
+
1040
    def LIST_EXTEND(self, inst: Instruction):
        """Decompile ``LIST_EXTEND``: emit ``temp.extend(values)`` where the
        target list is first materialized into a temp variable."""
        assert inst.argval == 1, "Only tested for argval==1"
        values = self.state.stack.pop()
        temp = self.replace_mutable_tos_with_temp()
        self.state.source_code += f"{temp}.extend({values})\n"
1045
+
1046
    def LIST_APPEND(self, inst: Instruction):
        """Decompile ``LIST_APPEND`` (list-comprehension append): emit
        ``container.append(value)``.

        ``argval`` is the container's depth below the popped value; the
        ``argval == 1`` case is patched because with that depth the indexing
        would hit the value itself instead of the container.
        """
        if inst.argval == 1:
            # it should be a bug, the tos should be the value. fix it anyway.
            inst.argval += 1
        # read the container before popping the value, so the negative index
        # counts from the current stack top
        container = self.state.stack[-inst.argval]
        value = self.state.stack.pop()
        self.state.source_code += f"{container}.append({value})\n"
1053
+
1054
    def generic_update(self, inst: Instruction):
        """Shared decompilation for ``SET_UPDATE``/``DICT_UPDATE``/``DICT_MERGE``:
        emit ``temp.update(values)`` on the materialized container."""
        assert inst.argval == 1, "Only tested for argval==1"
        values = self.state.stack.pop()
        temp = self.replace_mutable_tos_with_temp()
        self.state.source_code += f"{temp}.update({values})\n"

    SET_UPDATE = DICT_UPDATE = DICT_MERGE = generic_update
1061
+
1062
    def SET_ADD(self, inst: Instruction):
        """Decompile ``SET_ADD`` (set-comprehension add): emit
        ``container.add(value)``; same depth-patching hack as LIST_APPEND."""
        if inst.argval == 1:
            # it should be a bug, the tos should be the value. fix it anyway.
            inst.argval += 1
        # read the container before popping the value
        container = self.state.stack[-inst.argval]
        value = self.state.stack.pop()
        self.state.source_code += f"{container}.add({value})\n"
1069
+
1070
    def MAP_ADD(self, inst: Instruction):
        """Decompile ``MAP_ADD`` (dict-comprehension add): emit
        ``container.__setitem__(key, value)``.

        The key/value stack order flipped in Python 3.8; handle both.
        """
        container = self.state.stack[-inst.argval - 1]
        # see https://docs.python.org/3.10/library/dis.html#opcode-MAP_ADD
        if sys.version_info >= (3, 8):
            value = self.state.stack.pop()
            key = self.state.stack.pop()
        else:
            key = self.state.stack.pop()
            value = self.state.stack.pop()
        self.state.source_code += f"{container}.__setitem__({key}, {value})\n"
1080
+
1081
+ # ==================== Misc Instructions =============================
1082
+ def RAISE_VARARGS(self, inst: Instruction):
1083
+ if inst.argval == 0:
1084
+ self.state.source_code += "raise\n"
1085
+ elif inst.argval == 1:
1086
+ self.state.source_code += f"raise {self.state.stack.pop()}\n"
1087
+ elif inst.argval == 2:
1088
+ tos = self.state.stack.pop()
1089
+ tos1 = self.state.stack.pop()
1090
+ self.state.source_code += f"raise {tos1} from {tos}\n"
1091
+
1092
    def FORMAT_VALUE(self, inst: Instruction):
        """Decompile ``FORMAT_VALUE`` (f-string conversion): ``inst.argval``
        is a ``(conversion_func, has_format_spec)`` pair. With a format spec
        emit ``format(value, spec)``; otherwise call the conversion function
        by name (``str`` when no explicit conversion was requested).
        """
        func, spec = inst.argval
        if spec:
            form_spec = self.state.stack.pop()
            value = self.state.stack.pop()
            self.state.stack.append(f"format({value}, {form_spec})")
        else:
            value = self.state.stack.pop()
            func = str if func is None else func
            self.state.stack.append(f"{func.__name__}({value})")
1102
+
1103
+
1104
+ def decompile_range(self, start: int, end: int):
1105
+ try:
1106
+ running_index = start
1107
+ while running_index < end:
1108
+ inst = self.instructions[running_index]
1109
+ method = getattr(
1110
+ Decompiler,
1111
+ inst.opname,
1112
+ Decompiler.unimplemented_instruction)
1113
+ output = method(self, inst)
1114
+ if output:
1115
+ running_index = output
1116
+ else:
1117
+ running_index += 1
1118
+ except Exception as e:
1119
+ raise DecompilationError(
1120
+ f"Failed to decompile instruction {inst} in {self.code.co_name}") from e
1121
+
1122
+ def index_of(self, offset: int):
1123
+ for idx, inst in enumerate(self.instructions):
1124
+ if inst.offset == offset:
1125
+ return idx
1126
+ raise ValueError(f"Cannot find instruction with offset {offset}")
1127
+
1128
    @staticmethod
    def cleanup_instructions(code, instructions: List[Instruction]):
        """Normalize the raw instruction list in place before decompilation:
        propagate line numbers, simplify finally blocks, and NOP-out
        unreachable bytecode."""
        propagate_line_nums(instructions)
        simplify_finally_statement(instructions)
        nop_unreachable_bytecode(code, instructions)
1133
+
1134
    def __init__(self, code: Union[CodeType, Callable]):
        """Build a decompiler for ``code``.

        Accepts either a code object or any callable; callables are unwrapped
        to their underlying ``__code__`` via ``get_code_owner``.
        """
        if callable(code):
            from depyf.utils import get_code_owner
            code = get_code_owner(code).__code__
        self.code = code
        instructions = list(convert_instruction(_)
                            for _ in dis.get_instructions(code))
        # normalize line numbers / finally blocks / unreachable code up front
        Decompiler.cleanup_instructions(code, instructions)
        self.instructions = instructions
        # simulated evaluation stack plus accumulated source text
        self.state = DecompilerState(source_code="", stack=[])
1144
+
1145
    def get_temp_name(self):
        """Return a fresh temporary variable name.

        ``temp_count`` is a class attribute, so names are unique across all
        ``Decompiler`` instances in the process.
        """
        Decompiler.temp_count += 1
        return f"{self.temp_prefix}{Decompiler.temp_count}"
1148
+
1149
+ def replace_mutable_tos_with_temp(self):
1150
+ ans = self.state.stack.pop()
1151
+ temp_name = self.get_temp_name()
1152
+ self.state.source_code += f"{temp_name} = {ans}\n"
1153
+ self.state.stack.append(temp_name)
1154
+ return temp_name
1155
+
1156
+ @staticmethod
1157
+ def supported_opnames():
1158
+ opnames = []
1159
+ for x in dis.opname:
1160
+ if getattr(
1161
+ Decompiler,
1162
+ x,
1163
+ Decompiler.unimplemented_instruction) is not Decompiler.unimplemented_instruction:
1164
+ opnames.append(x)
1165
+ return opnames
1166
+
1167
    # caching is safe here because the class defines identity-based
    # __hash__/__eq__ (see below); see https://github.com/thuml/depyf/pull/21
    @functools.lru_cache(maxsize=None)
    def decompile(
            self,
            indentation=4,
            temp_prefix: str = "__temp_",
            overwite_fn_name: Optional[str] = None) -> str:
        """Decompile ``self.code`` into Python source.

        Args:
            indentation: number of spaces per indent level in the output.
            temp_prefix: prefix used for generated temp variable names.
            overwite_fn_name: if given, use this as the function name in the
                emitted signature (note: parameter name keeps its historical
                spelling for API compatibility).

        Returns:
            The full source text of the reconstructed function.

        Raises:
            DecompilationError: on any failure.
        """
        try:
            self.indentation = indentation
            self.temp_prefix = temp_prefix
            self.decompile_range(0, len(self.instructions))
            source_code = self.state.source_code
            # the header might have invalid function name in torchdynamo. only
            # optimize the function body.
            source_code = remove_some_temp(
                source_code, self.temp_prefix, indentation)
            header = get_function_signature(self.code, overwite_fn_name)
            # we cannot rely on `co_names`. For example, `from math import sqrt` will make `math` and `sqrt` in `co_names`.
            global_names = set(inst.argval for inst in dis.get_instructions(self.code) if inst.opname == "STORE_GLOBAL")
            global_statements = "global " + ", ".join(
                global_names) + "\n" if global_names else ""
            nonlocal_statement = "nonlocal " + ", ".join(
                self.code.co_freevars) + "\n" if self.code.co_freevars else ""
            source_code = global_statements + nonlocal_statement + source_code
            source_code = header + add_indentation(source_code, indentation)
            return source_code
        except DecompilationError:
            # already wrapped with context; re-raise untouched
            raise
        except Exception as e:
            raise DecompilationError(
                f"Failed to decompile {self.code.co_name}") from e
1197
+
1198
+ @staticmethod
1199
+ def decompile_and_compile_like(
1200
+ code_to_decompile: CodeType,
1201
+ reference_code: CodeType,
1202
+ indentation=4,
1203
+ temp_prefix: str = "__temp_",
1204
+ filepath_template: Optional[str] = None) -> CodeType:
1205
+
1206
+ # first, decompile the code into source code, with function name `__place_holder__`
1207
+ src = Decompiler(code_to_decompile).decompile(indentation=indentation, temp_prefix=temp_prefix, overwite_fn_name="__place_holder__")
1208
+
1209
+ # fix the freevars/cellvars in the source code
1210
+ from depyf.code_transform import fix_irregular_code
1211
+ # check https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/4 for why we need to prepare freevars like `reference_code` rather than `code`
1212
+ src = fix_irregular_code(reference_code, src)
1213
+
1214
+ if filepath_template is None:
1215
+ func_name = reference_code.co_name
1216
+ src = src.replace("__place_holder__", func_name)
1217
+ filename = "noname"
1218
+ else:
1219
+ src_body = src[src.find("("):]
1220
+ if reference_code.co_freevars:
1221
+ src_body = src_body[src_body.find("("):]
1222
+
1223
+ count = 0
1224
+ while True:
1225
+ filename = filepath_template % count
1226
+ if os.path.exists(filename):
1227
+ existing_code = open(filename, "r").read()
1228
+ existing_code_body = existing_code[existing_code.find("("):]
1229
+ if reference_code.co_freevars:
1230
+ existing_code_body = existing_code_body[existing_code_body.find("("):]
1231
+ if src_body == existing_code_body:
1232
+ # the same code body is found, we do not need to dump the code again.
1233
+ src = existing_code
1234
+ break
1235
+ else:
1236
+ count += 1
1237
+ else:
1238
+ func_name = filename.split(os.path.sep)[-1].split(".")[0]
1239
+ src = src.replace("__place_holder__", func_name)
1240
+ with open(filename, "w") as f:
1241
+ f.write(src)
1242
+ break
1243
+
1244
+ func_name = filename.split(os.path.sep)[-1].split(".")[0]
1245
+
1246
+ from depyf.utils import collect_all_code_objects
1247
+ transformed_code = compile(src, filename=filename, mode="exec")
1248
+ transformed_codes = collect_all_code_objects(transformed_code)
1249
+ decompiled_and_compiled_back_code = [x for x in transformed_codes if x.co_name == func_name][0]
1250
+
1251
+ # torch.compile might hold random non-constant values in `new_code.co_consts` that cannot
1252
+ # be represented in source code. During decompliation, we treat them as `__co_consts[i]`,
1253
+ # a string that represents the constant in the original code object.
1254
+ # We need to replace them with the actual constant in the original code object, so that
1255
+ # the decompiled and compiled back code object can be used for execution.
1256
+ updated_consts = []
1257
+ for i, x in enumerate(decompiled_and_compiled_back_code.co_consts):
1258
+ if isinstance(x, str) and x.startswith("__co_consts"):
1259
+ index = int(x.split("[")[-1][:-1]) # __co_consts[0] -> 0
1260
+ updated_consts.append(code_to_decompile.co_consts[index])
1261
+ else:
1262
+ updated_consts.append(x)
1263
+
1264
+ decompiled_and_compiled_back_code = decompiled_and_compiled_back_code.replace(co_consts=tuple(updated_consts))
1265
+
1266
+ return decompiled_and_compiled_back_code
1267
+
1268
    def __hash__(self):
        """Hash by the identity of the wrapped code object, so that
        ``lru_cache`` on ``decompile`` keys on the code object rather than on
        the (unhashable by default) instance."""
        # see https://github.com/thuml/depyf/pull/21
        return id(self.code)
1271
+
1272
    def __eq__(self, other):
        # consistent with __hash__: two decompilers are equal iff they wrap
        # the very same code object
        return hash(self) == hash(other)
1274
+
1275
def decompile(code: Union[CodeType, Callable]) -> str:
    """Decompile any callable or code object into Python source code.
    It is especially useful for some dynamically generated code, like ``torch.compile``,
    or ``dataclasses``.

    Args:
        code: a code object, or any callable whose ``__code__`` can be found.

    Returns:
        The reconstructed Python source as a string.

    Example usage:

    .. code-block:: python

        from dataclasses import dataclass
        @dataclass
        class Data:
            x: int
            y: float

        import depyf
        print(depyf.decompile(Data.__init__))
        print(depyf.decompile(Data.__eq__))

    Output:

    .. code-block:: python

        def __init__(self, x, y):
            self.x = x
            self.y = y
            return None

        def __eq__(self, other):
            if other.__class__ is self.__class__:
                return (self.x, self.y) == (other.x, other.y)
            return NotImplemented

    The output source code is semantically equivalent to the function, but not syntactically the same. It verbosely adds many details that are hidden in the Python code. For example, the above output code of ``__init__`` explicitly returns ``None``, which is typically ignored.

    Another detail is that the output code of ``__eq__`` returns ``NotImplemented`` instead of raising ``NotImplemented`` exception when the types are different. At the first glance, it seems to be a bug. However, it is actually the correct behavior. The ``__eq__`` method should return ``NotImplemented`` when the types are different, so that the other object can try to compare with the current object. See `the Python documentation <https://docs.python.org/3/library/numbers.html#implementing-the-arithmetic-operations>`_ for more details.
    """
    return Decompiler(code).decompile()
.venv/lib/python3.11/site-packages/depyf/explain/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from depyf.explain.utils import DynamoOptimizationResult
2
+
3
+ from torch._dynamo.eval_frame import innermost_fn
4
+
5
+ from typing import List, Callable, Dict, Union, Set
6
+ from types import CodeType
7
+
8
+
9
def _extract_artifacts(original_code: CodeType, module):
    """Collect Dynamo optimization artifacts for ``original_code`` executed
    against ``module`` globals, wrapped in a ``DynamoOptimizationResult``."""
    result = DynamoOptimizationResult(original_code, None, module)
    return result
12
+
13
def dump_src(original_code: CodeType, module):
    """Return the combined source dump for a compiled code object.

    Must run inside an active ``depyf.prepare_debug`` context, since the
    artifacts depend on the hooks that context installs.
    """
    from depyf.explain.global_variables import data
    assert data["is_inside_prepare_debug"], "`dump_src` must be used inside `depyf.prepare_debug`."
    artifacts = _extract_artifacts(original_code, module)
    return artifacts.to_src()
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.27 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enable_debugging.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enhance_logging.cpython-311.pyc ADDED
Binary file (3.81 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/global_variables.cpython-311.pyc ADDED
Binary file (911 Bytes). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched___call__.cpython-311.pyc ADDED
Binary file (737 Bytes). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched__exec_with_source.cpython-311.pyc ADDED
Binary file (1.43 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_boxed_run.cpython-311.pyc ADDED
Binary file (387 Bytes). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_lazy_format_graph_code.cpython-311.pyc ADDED
Binary file (4.47 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_load_by_key_path.cpython-311.pyc ADDED
Binary file (1.18 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/utils.cpython-311.pyc ADDED
Binary file (20.3 kB). View file
 
.venv/lib/python3.11/site-packages/depyf/explain/enable_debugging.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .patched_boxed_run import patched_boxed_run
2
+ from .patched_lazy_format_graph_code import patched_lazy_format_graph_code
3
+ from .patched_load_by_key_path import patched_load_by_key_path
4
+ from .patched__exec_with_source import patched__exec_with_source
5
+ from typing import List, Tuple, Dict, Union, Callable, Optional, Any
6
+
7
+ import contextlib
8
+ import warnings
9
+ import traceback
10
+
11
+ import dataclasses
12
+ import itertools
13
+ import sys
14
+ import os
15
+ import inspect
16
+
17
+
18
@dataclasses.dataclass
class DebuggableHook(object):
    """Dynamo bytecode hook that decompiles each transformed code object,
    dumps its source under ``dump_src_dir``, and returns a recompiled,
    debuggable code object for Dynamo to run instead."""

    # directory where decompiled sources (and optional bytecode dumps) go
    dump_src_dir: str
    # when True, also dill-dump the original/transformed/recompiled bytecode
    log_bytecode: bool
    # accumulated [code, frame globals] pairs for every compiled function
    optimized_code_and_module: List = dataclasses.field(default_factory=list, init=False)

    def __call__(self, code, new_code):
        # walk up the call stack to find Dynamo's `_compile` frame, from whose
        # locals we can grab the frame being compiled (and thus its globals)
        frame = sys._getframe()
        import os
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == code
        self.optimized_code_and_module.append([code, frame.f_globals])
        from depyf.decompiler import DecompilationError
        try:
            import os
            # replace " "/"<"/">"/"." with "_"
            func_name = code.co_name.replace(".", "_").replace("<", "_").replace(">", "_").replace(" ", "_")
            filepath_template = os.path.join(
                self.dump_src_dir,
                f"__transformed_code_%s_for_{func_name}.py")

            from depyf.explain.utils import lock_on_file
            from depyf.decompiler import Decompiler

            # function name and file name are related.
            with lock_on_file(filepath_template):
                decompiled_and_compiled_back_code = Decompiler.decompile_and_compile_like(code_to_decompile=new_code, reference_code=code, filepath_template=filepath_template)
                filename = decompiled_and_compiled_back_code.co_filename
            if self.log_bytecode:
                with lock_on_file(filename):
                    import dill
                    # code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`.
                    # see https://github.com/pytorch/pytorch/issues/116013 for more details.
                    with contextlib.suppress(Exception):
                        dill.dump(code, open(filename + ".original_bytecode", "wb"))

                    with contextlib.suppress(Exception):
                        dill.dump(new_code, open(filename + ".transformed_bytecode", "wb"))

                    with contextlib.suppress(Exception):
                        dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb"))

            # this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487
            from torch._dynamo.utils import orig_code_map
            from torch._dynamo.convert_frame import output_codes
            output_codes.add(decompiled_and_compiled_back_code)
            orig_code_map[decompiled_and_compiled_back_code] = code

            return decompiled_and_compiled_back_code
        except (DecompilationError, SyntaxError) as e:
            from io import StringIO
            string_io = StringIO()
            import dis
            print("There is a problem when decompiling and compiling the following code:", file=string_io)
            dis.dis(new_code, file=string_io)
            print("Please consider submitting an issue to https://github.com/thuml/depyf .", file=string_io)
            # do not stop the program for decompilation error and compile error
            warnings.warn(string_io.getvalue())
            traceback.print_exc()
83
+
84
@contextlib.contextmanager
def patch(parent, name, value):
    """Temporarily set ``parent.name = value``, restoring the original on
    exit.

    If the attribute does not already exist (``getattr`` returns ``None``),
    nothing is patched or restored — the patch is deliberately a no-op then.
    """
    old_value = getattr(parent, name, None)
    if old_value is not None:
        setattr(parent, name, value)
    try:
        yield
    finally:
        if old_value is not None:
            setattr(parent, name, old_value)
94
+
95
+
96
@contextlib.contextmanager
def enable_bytecode_hook(hook):
    """Register ``hook`` with Dynamo's bytecode-hook registry for the
    duration of the context, removing it on exit."""
    import torch
    handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
    try:
        yield
    finally:
        handle.remove()
104
+
105
+
106
@contextlib.contextmanager
def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):
    """
    A context manager to dump debugging information for torch.compile.
    It should wrap the code that actually triggers the compilation, rather than
    the code that applies ``torch.compile``.

    Example:

    .. code-block:: python

        import torch

        @torch.compile
        def toy_example(a, b):
            x = a / (torch.abs(a) + 1)
            if b.sum() < 0:
                b = b * -1
            return x * b

        def main():
            for _ in range(100):
                toy_example(torch.randn(10), torch.randn(10))

        if __name__ == "__main__":
            # main()
            # surround the code you want to run inside `with depyf.prepare_debug`
            import depyf
            with depyf.prepare_debug("./dump_src_dir"):
                main()

    After running the code, you will find the dumped information in the directory ``dump_src_dir``. The details are organized into the following:

    - ``full_code_for_xxx.py`` for each function using torch.compile
    - ``__transformed_code_for_xxx.py`` for Python code associated with each graph.
    - ``__transformed_code_for_xxx.py.xxx_bytecode`` for Python bytecode, dumped code object, can be loaded via ``dill.load(open("/path/to/file", "wb"))``. Note that the load function might import some modules like transformers. Make sure you have these modules installed.
    - ``__compiled_fn_xxx.py`` for each computation graph and its optimization:
        - ``Captured Graph``: a plain forward computation graph
        - ``Joint Graph``: joint forward-backward graph from AOTAutograd
        - ``Forward Graph``: forward graph from AOTAutograd
        - ``Backward Graph``: backward graph from AOTAutograd
        - ``kernel xxx``: compiled CPU/GPU kernel wrapper from Inductor.

    Arguments:

    - ``dump_src_dir``: the directory to dump the source code.
    - ``clean_wild_fx_code``: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
    - ``log_bytecode``: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code).
    """

    if not isinstance(dump_src_dir, str):
        raise RuntimeError('''You are using an obsolete usage style`depyf.prepare_debug(func=function, dump_src_dir="/path")`. Please use `depyf.prepare_debug(dump_src_dir="/path")` instead, which will automatically capture all compiled functions.''')

    import os
    import torch

    current_line_number = inspect.currentframe().f_lineno + 1
    warnings.warn_explicit(f"{__file__}:{current_line_number}: You are trying to debug `torch.compile`. Please make sure the code runs multiple times to cover all the possible branches.", UserWarning, "", 0)

    from depyf.utils import safe_create_directory

    if not os.path.exists(dump_src_dir):
        safe_create_directory(dump_src_dir)

    dump_src_dir = os.path.abspath(dump_src_dir)

    from .global_variables import data

    # record the unpatched callables so the patched versions (and `dump_src`)
    # can delegate back to them
    data["dump_src_dir"] = dump_src_dir
    data["unpatched__exec_with_source"] = torch.fx.graph_module._exec_with_source
    data["unpatched_load_by_key_path"] = torch._inductor.codecache.PyCodeCache.load_by_key_path
    data["unpatched___call__"] = torch._dynamo.eval_frame.OptimizeContext.__call__
    data["is_inside_prepare_debug"] = True

    bytecode_hook = DebuggableHook(dump_src_dir, log_bytecode)

    # patch some functions
    with patch(torch.fx.graph_module, "_exec_with_source", patched__exec_with_source), \
            patch(torch._inductor.codecache.PyCodeCache, "load_by_key_path", patched_load_by_key_path), \
            patch(torch._dynamo.utils.lazy_format_graph_code, "__code__", patched_lazy_format_graph_code.__code__):
        # we have to directly manipulate the code object, since the function has been imported in many places.
        # simply replacing torch._dynamo.utils.lazy_format_graph_code does not work for those functions.
        # Note: `unitest.mock.patch` does not work here, since it will not
        # patch the code object. (it will try to delete the code object and
        # then set a new code object. The `delattr` will raise an error.)

        # enable bytecode hook
        with enable_bytecode_hook(bytecode_hook):
            try:
                yield
            finally:

                # skip resume functions: their parent function's dump already
                # covers them
                code_names = {x[0].co_name for x in bytecode_hook.optimized_code_and_module}
                for code, module in bytecode_hook.optimized_code_and_module:
                    if code.co_name.startswith("resume_in_") and any(f"resume_in_{name}" in code.co_name for name in code_names):
                        continue
                    # https://github.com/pytorch/pytorch/pull/118201 introduces `torch_dynamo_resume_in_` names.
                    if code.co_name.startswith("torch_dynamo_resume_in_") and any(f"torch_dynamo_resume_in_{name}" in code.co_name for name in code_names):
                        continue
                    from depyf.explain import dump_src
                    from depyf.explain.utils import write_code_to_file_template
                    from torch._dynamo.eval_frame import innermost_fn, _debug_get_cache_entry_list
                    entries = _debug_get_cache_entry_list(code)
                    if not entries:
                        current_line_number = inspect.currentframe().f_lineno + 1
                        warnings.warn_explicit(f"{__file__}:{current_line_number}: Code object {code} is compiled but does not have any compiled cache entries. Probably some torch.nn.Module instances are destroyed too early. It is recommended to make sure the torch.nn.Module instances exist after `with depyf.prepare_debug`.", UserWarning, "", 0)
                    full_src = dump_src(code, module)
                    filepath_template = os.path.join(dump_src_dir, f"full_code_for_{code.co_name}_%s.py")
                    full_code_path = write_code_to_file_template(full_src, filepath_template)

                for file in os.listdir(dump_src_dir):
                    name = file.split(os.path.sep)[-1]
                    # remove *.lock file and possibly fx_graph_code* file
                    if (clean_wild_fx_code and name.startswith("fx_graph_code")) or name.endswith(".lock"):
                        try:
                            # multiple processes may try to remove the same file.
                            os.remove(os.path.join(dump_src_dir, file))
                        except OSError:
                            pass

                data["is_inside_prepare_debug"] = False
227
+
228
@contextlib.contextmanager
def debug():
    """
    A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program and allows you to check the full source code in files with prefix ``full_code_for_`` in the ``dump_src_dir`` argument of :func:`depyf.prepare_debug`, and set breakpoints in their separate ``__transformed_code_`` files according to the function name. Then continue your debugging.
    """
    from .global_variables import data
    if data["is_inside_prepare_debug"]:
        raise RuntimeError("You cannot use `depyf.debug` inside `depyf.prepare_debug`.")
    dump_src_dir = data["dump_src_dir"]
    import torch
    # after https://github.com/pytorch/pytorch/pull/131258
    # torch._dynamo.eval_frame.set_eval_frame is not available in the module
    # we need to directly access it from the `_C` extension.
    # disable Dynamo while debugging; the previous callback is restored on exit
    callback = torch._C._dynamo.eval_frame.set_eval_frame(False)
    # sometimes pytorch use Interpreter to run node by node. This cannot be debugged.
    # we patch this function to run the graph function directly.
    with patch(torch.fx.Interpreter.boxed_run, "__code__", patched_boxed_run.__code__):
        try:
            msg = f"`depyf` places a breakpoint here to pause the program. You can check the full source code in files with prefix `full_code_for_` in {dump_src_dir} first, and set breakpoints in their separate files according to the function name. Then continue your debugging."
            print(msg)
            breakpoint()
            yield
        finally:
            torch._C._dynamo.eval_frame.set_eval_frame(callback)
.venv/lib/python3.11/site-packages/depyf/explain/enhance_logging.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ from depyf.decompiler import decompile, DecompilationError
3
+
4
+
5
def pytorch_bytecode_src_hook(code: types.CodeType, new_code: types.CodeType):
    """Dynamo bytecode hook that logs the decompiled source of ``new_code``
    through PyTorch's ``bytecode`` logging artifact.

    Only does work when the ``bytecode`` artifact logger is enabled at DEBUG
    level (``TORCH_LOGS="+bytecode"``). Never raises: decompilation failures
    are logged and swallowed.
    """
    import torch
    bytecode_log = torch._logging.getArtifactLogger(
        "torch._dynamo.convert_frame", "bytecode"
    )
    import logging

    if bytecode_log.isEnabledFor(logging.DEBUG):
        try:
            decompiled_src = decompile(new_code)
            bytecode_log.debug("possible source code:")
            bytecode_log.debug(decompiled_src)
        except DecompilationError as e:
            bytecode_log.debug("Decompilation fails due to: %s", str(e))
        finally:
            # fix: the implicitly concatenated message parts previously lacked
            # a separating space ("wrong,please")
            bytecode_log.debug(
                "If you find the decompiled code is wrong, "
                "please submit an issue at "
                "https://github.com/thuml/depyf/issues."
            )
+ )
25
+
26
+
27
+ _handle = None
28
+
29
+
30
def install():
    """
    Install the bytecode hook for PyTorch, integrate into PyTorch's logging system.

    Calling it again while already installed is a no-op.

    Example:

    .. code-block:: python

        import torch
        import depyf
        depyf.install()
        # anything with torch.compile
        @torch.compile
        def f(a, b):
            return a + b
        f(torch.tensor(1), torch.tensor(2))

    Turn on bytecode log by ``export TORCH_LOGS="+bytecode"``, and execute the script.
    We will see the decompiled source code in the log:

    .. code-block:: text

        ORIGINAL BYTECODE f test.py line 5
        7           0 LOAD_FAST                0 (a)
                    2 LOAD_FAST                1 (b)
                    4 BINARY_ADD
                    6 RETURN_VALUE


        MODIFIED BYTECODE f test.py line 5
        5           0 LOAD_GLOBAL              0 (__compiled_fn_1)
                    2 LOAD_FAST                0 (a)
                    4 LOAD_FAST                1 (b)
                    6 CALL_FUNCTION            2
                    8 UNPACK_SEQUENCE          1
                   10 RETURN_VALUE


        possible source code:
        def f(a, b):
            __temp_2, = __compiled_fn_1(a, b)
            return __temp_2

        If you find the decompiled code is wrong,please submit an issue at https://github.com/thuml/depyf/issues.

    To uninstall the hook, use :func:`depyf.uninstall()`.
    """
    import torch
    global _handle
    # idempotent: do not register the hook twice
    if _handle is not None:
        return
    _handle = torch._dynamo.convert_frame.register_bytecode_hook(
        pytorch_bytecode_src_hook)
83
+
84
+
85
def uninstall():
    """
    Uninstall the bytecode hook for PyTorch.
    Should be called after :func:`depyf.install()`.
    """
    global _handle
    # idempotent: nothing to do if the hook was never installed
    if _handle is None:
        return
    _handle.remove()
    _handle = None
.venv/lib/python3.11/site-packages/depyf/explain/global_variables.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch

from torch._inductor.codecache import PyCodeCache

# Shared mutable state for depyf.explain. `prepare_debug` overwrites these
# entries on entry; the defaults below let the patched functions fall back to
# the original torch implementations when no debug session is active.
data = {
    # where decompiled sources are dumped
    "dump_src_dir": os.path.join(os.path.dirname(__file__), "dumped_src"),
    # original (unpatched) torch callables, saved so patches can delegate
    "unpatched__exec_with_source": torch.fx.graph_module._exec_with_source,
    "unpatched_load_by_key_path": PyCodeCache.load_by_key_path,
    "unpatched___call__": torch._dynamo.eval_frame.OptimizeContext.__call__,
    # code objects that went through torch.compile
    "optimized_functions": set(),
    # True only while inside a `depyf.prepare_debug` context
    "is_inside_prepare_debug": False,
}
.venv/lib/python3.11/site-packages/depyf/explain/patched___call__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ def patched___call__(self, code, check_fn):
2
+ from depyf.explain.global_variables import data
3
+ from depyf.utils import get_code_owner
4
+ import torch
5
+ unpatched___call__ = data["unpatched___call__"]
6
+ optimized_functions = data["optimized_functions"]
7
+ optimized_functions.add(code)
8
+
9
+ return unpatched___call__(self, code, check_fn)
.venv/lib/python3.11/site-packages/depyf/explain/patched__exec_with_source.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def patched__exec_with_source(src: str, globals, co_fields=None):
2
+ from depyf.explain.global_variables import data
3
+ from depyf.explain.utils import write_code_to_file_template
4
+ dump_src_dir = data["dump_src_dir"]
5
+ unpatched__exec_with_source = data["unpatched__exec_with_source"]
6
+ unpatched__exec_with_source(src, globals, co_fields)
7
+ import inspect
8
+ key = inspect.getsourcefile(globals["forward"])
9
+ import hashlib
10
+ import os
11
+ hash_value = hashlib.md5(src.encode()).hexdigest()
12
+ src = "# " + key + src
13
+ filename = write_code_to_file_template(
14
+ src,
15
+ f"{dump_src_dir}/fx_graph_code_" +
16
+ hash_value +
17
+ "_" +
18
+ "%s" +
19
+ ".py")
20
+ exec(compile(src, filename, "exec"), globals)
.venv/lib/python3.11/site-packages/depyf/explain/patched_boxed_run.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ def patched_boxed_run(self, args_list):
2
+ return self.module.forward(*args_list)
.venv/lib/python3.11/site-packages/depyf/explain/patched_lazy_format_graph_code.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
2
+ from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template
3
+ from depyf.utils import get_code_owner
4
+ # When using torch export, the name includes
5
+ # a dumped dict of the nn_module_stack of a node in the module after the ':'
6
+ if ':' in name:
7
+ name = name.split(':')[0]
8
+ func_name = get_current_compiled_fn_name()
9
+ file_name = name if name != func_name else "Captured Graph"
10
+ file_name = file_name.replace(" ", "_")
11
+ file_name = func_name + "." + file_name
12
+ import inspect
13
+ import os
14
+
15
+ # https://github.com/pytorch/pytorch/pull/117911 introduces LazyGraphModule
16
+ # whose `forward` method is mangled and cannot be manipulated.
17
+ # We need to get rid of the laziness by calling `str` on it.
18
+ gm_s = str(gm)
19
+
20
+ fn = gm.forward
21
+
22
+ fn = get_code_owner(fn)
23
+
24
+ # update file path
25
+ filepath = inspect.getsourcefile(fn)
26
+ # try to use verbose code with type and shape annotations
27
+ use_gm = True
28
+
29
+ # use `print_readable` because it can include submodules
30
+ src = "from __future__ import annotations\nimport torch\n" + \
31
+ gm.print_readable(print_output=False)
32
+ src = src.replace("<lambda>", "GraphModule")
33
+ try:
34
+ compile(src, "noname", "exec")
35
+ except Exception as e:
36
+ # the pytorch version is before this PR: https://github.com/pytorch/pytorch/pull/113345
37
+ # Verbose code contains syntax error, it is recommended to use new
38
+ # version of PyTorch to get runnable code with shape and type
39
+ # annotations.
40
+ simple_code = gm._graph.python_code(root_module="self", verbose=False).src
41
+ commented_src = "\n# code below is commented out due to syntax error. You can refer to the code for shape and dtype annotation.\n"
42
+ commented_src += "".join(["# " + line +
43
+ "\n" for line in src.splitlines()])
44
+ src = simple_code + commented_src
45
+ use_gm = False
46
+ if filepath is not None:
47
+ new_filepath = write_code_to_file_template(
48
+ src, os.path.dirname(filepath) + "/" + file_name + "." + "%s" + ".py")
49
+ scope = fn.__globals__
50
+ exec(compile(src, filename=new_filepath, mode="exec"), scope)
51
+ if use_gm:
52
+ import torch
53
+ classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)]
54
+ assert len(classes) == 1
55
+ module_class = classes[0]
56
+ fn.__code__ = getattr(module_class, fn.__name__).__code__
57
+ else:
58
+ fn.__code__ = scope[fn.__name__].__code__
59
+ del scope[fn.__name__]
60
+
61
+ # =========================================
62
+ # original code of `lazy_format_graph_code`
63
+ def format_name():
64
+ if maybe_id is not None:
65
+ return f"{name} {maybe_id}"
66
+ else:
67
+ return name
68
+
69
+ if "print_output" not in kwargs:
70
+ kwargs["print_output"] = False
71
+
72
+ return LazyString(
73
+ lambda: _format_graph_code(
74
+ f"===== {format_name()} =====\n",
75
+ gm.forward.__code__.co_filename,
76
+ gm.print_readable(**kwargs),
77
+ )
78
+ )
.venv/lib/python3.11/site-packages/depyf/explain/patched_load_by_key_path.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def patched_load_by_key_path(
2
+ key: str,
3
+ path: str,
4
+ linemap,
5
+ attrs,
6
+ ):
7
+ from depyf.explain.global_variables import data
8
+ from depyf.explain.utils import write_code_to_file_template, get_current_compiled_fn_name
9
+ dump_src_dir = data["dump_src_dir"]
10
+ unpatched_load_by_key_path = data["unpatched_load_by_key_path"]
11
+ import os
12
+ # hack the path to our dump_src_dir
13
+ src = open(path).read()
14
+ # do not remove. remove in multi-processes will cause error.
15
+ # os.remove(path)
16
+
17
+ func_name = get_current_compiled_fn_name()
18
+ new_filepath = write_code_to_file_template(src, os.path.join(
19
+ dump_src_dir, func_name + ".kernel_" + "%s" + ".py"))
20
+ path = new_filepath
21
+ return unpatched_load_by_key_path(key, path, linemap, attrs)
.venv/lib/python3.11/site-packages/depyf/explain/utils.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch._dynamo.eval_frame import innermost_fn
3
+ from torch._dynamo.eval_frame import _debug_get_cache_entry_list
4
+ import inspect
5
+
6
+ import dis
7
+ from types import CodeType
8
+ from typing import List, Callable, Dict, Union, Set
9
+ from dataclasses import dataclass
10
+ import contextlib
11
+
12
+ class CodeProxy:
13
+ instances: Dict[str, "CodeProxy"] = {}
14
+ used_instances: Set[str] = set()
15
+
16
+ @staticmethod
17
+ def get_new_name(name: str):
18
+ i = 0
19
+ new_name = name
20
+ if new_name.endswith(":"):
21
+ name = name[:-1]
22
+ while True:
23
+ new_name = f"{name}_{i}"
24
+ if new_name not in CodeProxy.instances:
25
+ break
26
+ i += 1
27
+ return new_name
28
+
29
+ @staticmethod
30
+ def consume_new_name(name: str):
31
+ new_name = CodeProxy.get_new_name(name)
32
+ CodeProxy.instances[new_name] = None
33
+ return new_name
34
+
35
+ @staticmethod
36
+ def decompile_with_name(code: CodeType, name: str, skip_decompile=False):
37
+ from depyf.utils import decompile_ensure
38
+ if hasattr(code, "__code__"):
39
+ code = code.__code__
40
+ if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"):
41
+ src = open(code.co_filename).read()
42
+ new_name = code.co_name
43
+ else:
44
+ new_name = CodeProxy.get_new_name(name)
45
+ if not skip_decompile:
46
+ src = decompile_ensure(code, new_name)
47
+ else:
48
+ src = ""
49
+ self = CodeProxy(src)
50
+ self.name = new_name
51
+ self.code = f"""<details>
52
+ <summary>{self.name}</summary>
53
+
54
+ ```python
55
+ {self.raw_code}
56
+ ```
57
+ </details>
58
+ """
59
+ CodeProxy.instances[self.name] = self
60
+ return self
61
+
62
+ def __init__(self, code: str):
63
+ # Don't directly use this constructor. Use decompile_with_name instead.
64
+ self.raw_code = "".join(
65
+ [" " + line + "\n" for line in code.splitlines() if line.strip() != ""])
66
+
67
+ def __str__(self):
68
+ CodeProxy.used_instances.add(self.name)
69
+ return self.name
70
+
71
+ @contextlib.contextmanager
72
+ @staticmethod
73
+ def record():
74
+ CodeProxy.used_instances = set()
75
+ yield CodeProxy.used_instances
76
+
77
+
78
+ @dataclass
79
+ class CacheResult:
80
+ original_code: CodeType
81
+ transformed_code: CodeType
82
+ guard: List[str]
83
+ compiled_subgraph: Callable
84
+ compiled_subgraph_proxy: CodeProxy
85
+ transformed_code_proxy: CodeProxy
86
+ referenced_global_functions: Dict[str, "DynamoOptimizationResult"]
87
+
88
+ def __init__(self, original_code, module, cache):
89
+ self.original_code = original_code
90
+
91
+ cpp_guard = False
92
+
93
+ # starting from https://github.com/pytorch/pytorch/pull/138896 ,
94
+ # pytorch uses `guard_manager` instead of `check_fn` to store the
95
+ # guards
96
+ attr_name = "guard_manager" if hasattr(cache, "guard_manager") else "check_fn"
97
+
98
+ guard_manager = getattr(cache, attr_name)
99
+
100
+ try:
101
+ klass = getattr(torch._dynamo.guards, "GuardManagerWrapper", None) or \
102
+ getattr(torch._dynamo.guards, "GuardManager", None) or \
103
+ getattr(torch._C._dynamo.guards, "GuardManager", None)
104
+ assert klass is not None
105
+ cpp_guard = isinstance(guard_manager, klass)
106
+ except Exception:
107
+ pass
108
+
109
+ if not cpp_guard:
110
+ # for old version of pytorch,
111
+ # `guard_manager` is a plain python function
112
+ guard_codes = guard_manager.code_parts
113
+ freevar_names = guard_manager.__code__.co_freevars
114
+ freevar_values = [x.cell_contents for x in guard_manager.__closure__]
115
+ else:
116
+ # keep the logic synced with
117
+ # https://github.com/pytorch/pytorch/blob/7b6b10417d8616ebd7a42b06528c5c2b2fded55a/torch/_dynamo/guards.py#L262
118
+ tensor_aliasing_guard_seen = False
119
+ def visit(root, ans):
120
+ nonlocal tensor_aliasing_guard_seen
121
+ for leaf_guard in root.get_leaf_guards():
122
+ if isinstance(leaf_guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING):
123
+ if not tensor_aliasing_guard_seen:
124
+ tensor_aliasing_guard_seen = True
125
+ else:
126
+ continue
127
+ append_guard_code(leaf_guard, ans)
128
+ for child in root.get_child_managers():
129
+ visit(child, ans)
130
+ guard_codes = []
131
+ root = guard_manager.root
132
+
133
+ # Add guards in RootGuardManager
134
+ visit(root, guard_codes)
135
+ # Add guards in epilogue lambda guards
136
+ if hasattr(root, "get_epilogue_lambda_guards"):
137
+ for lambda_guard in root.get_epilogue_lambda_guards():
138
+ append_guard_code(lambda_guard, guard_codes)
139
+
140
+ if guard_manager.closure_vars is None:
141
+ freevar_names = tuple()
142
+ freevar_values = []
143
+ else:
144
+ freevar_names = tuple(guard_manager.closure_vars.keys())
145
+ freevar_values = list(guard_manager.closure_vars.values())
146
+
147
+ self.guard = guard_codes
148
+ self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)}
149
+ code = cache.code
150
+
151
+ compiled_subgraphs = [
152
+ name for name in code.co_names if name.startswith("__compiled")]
153
+ assert len(compiled_subgraphs) <= 1
154
+
155
+ if compiled_subgraphs:
156
+ # deal with compiled_subgraph
157
+ self.compiled_subgraph = innermost_fn(module[compiled_subgraphs[0]])
158
+ # subgraph does not need decompile
159
+ self.compiled_subgraph_proxy = CodeProxy.decompile_with_name(
160
+ self.compiled_subgraph, compiled_subgraphs[0], skip_decompile=True)
161
+ else:
162
+ self.compiled_subgraph = None
163
+ self.compiled_subgraph_proxy = None
164
+ # deal with transformed_code
165
+ self.transformed_code = code
166
+ self.transformed_code_proxy = CodeProxy.decompile_with_name(
167
+ self.transformed_code, "transformed_code:")
168
+ resume_fns = [
169
+ name for name in code.co_names if name.startswith("__resume")]
170
+ self.referenced_global_functions = {}
171
+ for name in resume_fns:
172
+ self.referenced_global_functions[name] = DynamoOptimizationResult(
173
+ original_code=module[name].__code__,
174
+ function_name=name,
175
+ module=module)
176
+
177
+ def to_data(self):
178
+ return {
179
+ "guard": self.guard,
180
+ "transformed_code": str(
181
+ self.transformed_code_proxy),
182
+ "compiled_subgraph": str(
183
+ self.compiled_subgraph_proxy) if self.compiled_subgraph_proxy is not None else '"No compiled subgraph."',
184
+ "referenced_global_functions": {
185
+ name: fn.to_data() for name,
186
+ fn in self.referenced_global_functions.items()}}
187
+
188
+
189
+ @dataclass
190
+ class DynamoOptimizationResult:
191
+ function_name: str
192
+ module: dict
193
+ original_code: CodeType
194
+ source_code_proxy: CodeProxy
195
+ transformed_code_entries: List[CacheResult]
196
+
197
+ def __init__(self, original_code, function_name=None, module=None):
198
+ self.original_code = original_code
199
+ if function_name is None:
200
+ self.function_name = original_code.co_name
201
+ else:
202
+ self.function_name = function_name
203
+ self.module = module
204
+ caches = _debug_get_cache_entry_list(original_code)
205
+ self.transformed_code_entries = [
206
+ CacheResult(original_code, module, cache) for cache in caches]
207
+ self.source_code_proxy = CodeProxy.decompile_with_name(
208
+ self.original_code, self.function_name)
209
+
210
+ def to_data(self):
211
+ data = {
212
+ "function_name": self.function_name,
213
+ "source_code": str(
214
+ self.source_code_proxy),
215
+ "transformed_code_entries": [
216
+ entry.to_data() for entry in self.transformed_code_entries]}
217
+ return data
218
+
219
+ def to_src(self):
220
+ raw_code = self.source_code_proxy.raw_code
221
+
222
+ # prepare function signature, from `def toy_example(a, b)` to `def
223
+ # transformed_toy_example(a, b)`
224
+ signature = raw_code.splitlines()[0].replace(
225
+ "def ", "def transformed_", 1)
226
+ code = signature.strip()
227
+
228
+ # prepare args for guards, like `L = {"a": a, "b": b}`
229
+ code_obj = self.original_code
230
+ normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount
231
+ arg_names = code_obj.co_varnames[:normal_arg_count]
232
+ arg_dict = "__local_dict = {" + \
233
+ ", ".join([f'"{name}": {name}' for name in arg_names]) + "}"
234
+ code += "\n" + " " * 4 + arg_dict
235
+ code += "\n" + " " * 4 + "__global_dict = globals()"
236
+
237
+ additional_code = ""
238
+
239
+ for entry in self.transformed_code_entries:
240
+
241
+ # prepare guards, like `def guard_0(L):\n return a > 0 and b >
242
+ # 0`
243
+ freevars = "".join([f"{name} = '''{value}'''\n" for name, value in entry.freevars.items() if name not in ["__builtins__"]])
244
+ if freevars:
245
+ freevars = "# Note: the following variables are used inside the guard function.\n" + freevars
246
+ guard_lines = [" " * 4 + "__guard_hit = True\n"]
247
+ for x in entry.guard:
248
+ guard_lines.append(" " * 4 + f"__guard_hit = __guard_hit and {x}\n")
249
+ guard_lines.append(" " * 4 + "return __guard_hit\n")
250
+ guard = "".join(guard_lines)
251
+ if entry.transformed_code_proxy.name.startswith("__transformed_code_"):
252
+ guard_func_name = entry.transformed_code_proxy.name.replace("__transformed_code_", "__guard_")
253
+ else:
254
+ guard_func_name = CodeProxy.consume_new_name("guard:")
255
+ additional_code += "\n" + freevars + f"def {guard_func_name}(L, G, **___kwargs_ignored):\n" + guard
256
+
257
+ if entry.compiled_subgraph_proxy is not None:
258
+ # prepare compiled subgraph, like `__compiled_fn_0`
259
+ subgraph_name = entry.compiled_subgraph_proxy.name
260
+ additional_code += "\n"
261
+ additional_code += f"# Note: please refer to the graph code in {subgraph_name}*.py.\n"
262
+ additional_code += f"# Captured Graph: Dynamo generated graph (debuggable when using eager backend).\n"
263
+ additional_code += f"# Joint graph: joint forward+backward graph from aot autograd.\n"
264
+ additional_code += f"# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).\n"
265
+ additional_code += f"# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).\n"
266
+ additional_code += f"# AFTER XXX: graph processed by inductor (not debuggable).\n"
267
+ additional_code += f"def {subgraph_name}(*args, **kwargs):\n pass\n"
268
+
269
+ # prepare transformed code, like `transformed_code_0`
270
+ additional_code += "\n" + \
271
+ remove_indentation(entry.transformed_code_proxy.raw_code) + "\n"
272
+
273
+ for name, func in entry.referenced_global_functions.items():
274
+ additional_code = func.to_src() + additional_code
275
+
276
+ code += "\n" + " " * 4 + \
277
+ f"if {guard_func_name}(__local_dict, __global_dict):\n" + " " * 8 + f"return {entry.transformed_code_proxy.name}({', '.join(arg_names)})"
278
+
279
+ additional_code += "\n" + "# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.\n" + \
280
+ remove_indentation(self.source_code_proxy.raw_code) + "\n"
281
+
282
+ code += "\n" + " " * 4 + "# Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.\n" + \
283
+ " " * 4 + f"return {self.source_code_proxy.name}({', '.join(arg_names)})"
284
+ return additional_code + code + \
285
+ f"\n\n#============ end of {self.function_name} ============#\n"
286
+
287
+
288
+ def remove_indentation(code: str):
289
+ lines = code.splitlines()
290
+ indent = len(lines[0]) - len(lines[0].lstrip())
291
+ return "".join([line[indent:] + "\n" for line in lines])
292
+
293
+ def append_guard_code(guard, ans):
294
+ for verbose_str in guard.verbose_code_parts():
295
+ verbose_str = verbose_str.strip()
296
+ ans.append(verbose_str)
297
+
298
+ from contextlib import contextmanager
299
+
300
+ @contextmanager
301
+ def lock_on_file(path_template):
302
+ lock_path = path_template + ".lock"
303
+ from filelock import FileLock
304
+ import os
305
+ lock = FileLock(lock_path)
306
+ try:
307
+ with lock:
308
+ yield
309
+ finally:
310
+ pass
311
+
312
+
313
+ def write_code_to_file_template(src, path_template):
314
+ with lock_on_file(path_template):
315
+ import os
316
+ count = 0
317
+ while True:
318
+ new_filepath = path_template % str(count)
319
+ if not os.path.exists(new_filepath):
320
+ with open(new_filepath, "w") as f:
321
+ f.write(src)
322
+ break
323
+ # might be a hash collision
324
+ existing_code = open(new_filepath).read()
325
+ if existing_code == src:
326
+ break
327
+ count += 1
328
+ return new_filepath
329
+
330
+
331
+ def get_current_compiled_fn_name():
332
+ import torch
333
+ from torch._dynamo.bytecode_transformation import _unique_id_counter
334
+ from copy import copy
335
+ # torch.compile already called the next, we should add minus 1 to get the
336
+ # correct name
337
+ current_count = next(copy(_unique_id_counter)) - 1
338
+ return "__compiled_fn_" + str(current_count)
.venv/lib/python3.11/site-packages/depyf/optimization.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from abc import abstractmethod
4
+ from contextlib import contextmanager
5
+ from types import CodeType
6
+ from typing import Callable, List
7
+
8
+ import torch
9
+
10
+
11
+ class TorchCompileWrapperWithCustomDispatcher:
12
+ """
13
+ A wrapper class for torch.compile, with a custom dispatch logic.
14
+ Subclasses should:
15
+ 1. Implement the forward method
16
+ 2. Implement the dispatch logic in the __call__ method
17
+ It can use `self.compiled_codes` to access the compiled bytecode,
18
+ and `with self.dispatch_to_code(index):` to dispatch to
19
+ the compiled code.
20
+ 3. Implement the `__init__` method to determine how to call
21
+ `torch.compile` over the forward method.
22
+ """
23
+
24
+ def __init__(self, compiled_callable: Callable, use_custom_dispatcher: bool = True):
25
+ self.compiled_callable = compiled_callable
26
+ self.original_code_object = self.__class__.forward.__code__
27
+ self.compiled_codes: List[CodeType] = []
28
+ torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
29
+
30
+ self.use_custom_dispatcher: bool = use_custom_dispatcher
31
+
32
+ def __call__(self, *args, **kwargs):
33
+ """Implement the dispatch logic here, beyond the torch.compile level.
34
+ NOTE: this function can have additional arguments beyond the forward
35
+ method, for directly dispatching to the compiled code.
36
+ """
37
+ return self.compiled_callable(*args, **kwargs)
38
+
39
+ @abstractmethod
40
+ def forward(self, *args, **kwargs):
41
+ ...
42
+
43
+ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
44
+ """Hook to save the compiled bytecode for direct execution."""
45
+ if old_code is not self.original_code_object:
46
+ return
47
+ frame = sys._getframe()
48
+ while True:
49
+ frame = frame.f_back
50
+ code_name = frame.f_code.co_name
51
+ file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
52
+ if code_name == "_compile" and file_name == "convert_frame.py":
53
+ break
54
+ frame = frame.f_locals["frame"]
55
+ assert frame.f_code == old_code
56
+
57
+ if frame.f_locals["self"] is not self:
58
+ return
59
+
60
+ self.compiled_codes.append(new_code)
61
+
62
+ @contextmanager
63
+ def dispatch_to_code(self, index: int):
64
+ """Context manager to dispatch to the compiled code.
65
+ Why does this work? Because Dynamo guarantees that the compiled
66
+ bytecode has exactly the same arguments, cell variables, and free
67
+ variables as the original code. Therefore we can directly switch
68
+ the code object in the function and call it.
69
+
70
+ See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
71
+ """ # noqa
72
+ self.__class__.forward.__code__ = self.compiled_codes[index]
73
+ yield
74
+ self.__class__.forward.__code__ = self.original_code_object
.venv/lib/python3.11/site-packages/depyf/utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dis
2
+ from typing import List, Tuple, Union, Optional, Callable, Any, Dict, Set
3
+ from types import CodeType
4
+
5
+
6
+ def get_function_signature(code_obj: CodeType,
7
+ overwite_fn_name: Optional[str] = None) -> str:
8
+ # Extract all required details from the code object
9
+ # Sometimes the code object does not have a name, e.g. when it is a lambda
10
+ # function, so we can overwrite it to be a valid name
11
+ normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount
12
+ arg_names = code_obj.co_varnames[:normal_arg_count]
13
+ arg_names = [
14
+ x if not x.startswith(".") else x.replace(
15
+ ".", "comp_arg_") for x in arg_names]
16
+
17
+ import inspect
18
+ if code_obj.co_flags & inspect.CO_VARARGS:
19
+ arg_names.append('*' + code_obj.co_varnames[normal_arg_count])
20
+ normal_arg_count += 1
21
+ if code_obj.co_flags & inspect.CO_VARKEYWORDS:
22
+ arg_names.append('**' + code_obj.co_varnames[normal_arg_count])
23
+ normal_arg_count += 1
24
+ args_str = ', '.join(arg_names)
25
+ fn_name = overwite_fn_name if overwite_fn_name is not None else code_obj.co_name
26
+ header = f"def {fn_name}({args_str}):\n"
27
+ return header
28
+
29
+
30
+ def collect_all_code_objects(code: CodeType) -> List[CodeType]:
31
+ code_objects = [code]
32
+ for const in code.co_consts:
33
+ if isinstance(const, type(code)):
34
+ code_objects.extend(collect_all_code_objects(const))
35
+ return code_objects
36
+
37
+
38
+ def safe_create_directory(path):
39
+ # allow multiple processes to create the same directory
40
+ import os
41
+ try:
42
+ os.makedirs(path, exist_ok=True)
43
+ except OSError as e:
44
+ if not os.path.isdir(path):
45
+ raise
46
+
47
+
48
+
49
+ def get_code_owner(fn):
50
+ """A callable object `fn` might have a __code__ attribute, which is a code object.
51
+ However, `fn` might not be the owner of the code object. Only the code owner can change the code object.
52
+ This function returns the owner of the code object.
53
+ An example:
54
+ class A:
55
+ def func(self):
56
+ return 1
57
+ a = A()
58
+ `a.func.__code__` is read-only. `A.func.__code__` is writable.
59
+ We can change the code object via `a.func.__func__.__code__`.
60
+ """
61
+ import functools
62
+ while True:
63
+ if hasattr(fn, "__func__"):
64
+ # deal with bounded function
65
+ fn = fn.__func__
66
+ elif hasattr(fn, "__wrapped__"):
67
+ # deal with lru_cache or other decorators
68
+ fn = fn.__wrapped__
69
+ elif isinstance(fn, functools.partial):
70
+ # deal with partial function
71
+ fn = fn.func
72
+ elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"):
73
+ # deal with callable object
74
+ fn = fn.__call__.__func__
75
+ else:
76
+ break
77
+ return fn
78
+
79
+
80
+
81
+ def decompile_ensure(fn: CodeType, overwite_fn_name=None):
82
+ import depyf
83
+ from depyf.decompiler import DecompilationError
84
+ try:
85
+ decompiled_source_code = depyf.Decompiler(
86
+ fn).decompile(overwite_fn_name=overwite_fn_name)
87
+ except DecompilationError as e:
88
+ header = get_function_signature(fn, overwite_fn_name=overwite_fn_name)
89
+ decompiled_source_code = header + " 'Failed to decompile.'\n"
90
+ return decompiled_source_code
.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/LICENSE ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2014-2022 Matthew Brennan Jones <matthew.brennan.jones@gmail.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
6
+ this software and associated documentation files (the "Software"), to deal in
7
+ the Software without restriction, including without limitation the rights to
8
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9
+ the Software, and to permit persons to whom the Software is furnished to do so,
10
+ subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/METADATA ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: py-cpuinfo
3
+ Version: 9.0.0
4
+ Summary: Get CPU info with pure Python
5
+ Home-page: https://github.com/workhorsy/py-cpuinfo
6
+ Author: Matthew Brennan Jones
7
+ Author-email: matthew.brennan.jones@gmail.com
8
+ License: MIT
9
+ Platform: UNKNOWN
10
+ Classifier: Development Status :: 5 - Production/Stable
11
+ Classifier: Topic :: Utilities
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ License-File: LICENSE
15
+
16
+ py-cpuinfo
17
+ ==========
18
+
19
+
20
+ Py-cpuinfo gets CPU info with pure Python. Py-cpuinfo should work
21
+ without any extra programs or libraries, beyond what your OS provides.
22
+ It does not require any compilation(C/C++, assembly, et cetera) to use.
23
+ It works with Python 3.
24
+
25
+ Documentation can be viewed here: https://github.com/workhorsy/py-cpuinfo
26
+
27
+
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ../../../bin/cpuinfo,sha256=UfxJjvVjzhK2GanXPckc2fcNn-4NDzyKR188s6H-PNo,224
2
+ cpuinfo/__init__.py,sha256=T6gndqGAggfJCu4_iOziTnomCN7KzaAK_OYTewE4FMA,44
3
+ cpuinfo/__main__.py,sha256=nSxC6Hqhi-0lN7Z4WwtKdxQdf3cUJefb5hOahCzh4Yg,33
4
+ cpuinfo/__pycache__/__init__.cpython-311.pyc,,
5
+ cpuinfo/__pycache__/__main__.cpython-311.pyc,,
6
+ cpuinfo/__pycache__/cpuinfo.cpython-311.pyc,,
7
+ cpuinfo/cpuinfo.py,sha256=HHyDlDUNovE3QzJ3hviiM1ngyOC4iD7i6oGiz2iTmVk,84388
8
+ py_cpuinfo-9.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
9
+ py_cpuinfo-9.0.0.dist-info/LICENSE,sha256=3br3Y5a_XHqkWXWiHq_i4i7st9paoNt8sOYVL6r-800,1127
10
+ py_cpuinfo-9.0.0.dist-info/METADATA,sha256=rRFelvhFdoYcXnXXYDAbgdIxQ8_iVUa5lUHgEmU3ncE,794
11
+ py_cpuinfo-9.0.0.dist-info/RECORD,,
12
+ py_cpuinfo-9.0.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
13
+ py_cpuinfo-9.0.0.dist-info/entry_points.txt,sha256=ZwrsclY_xUA0xJZK98bLxBdcowxnkK0ANYUT4FYcZJ8,42
14
+ py_cpuinfo-9.0.0.dist-info/top_level.txt,sha256=XsjpunhkxD4hvznqQjrFNw0rtgizHEOGzewPZY3UEtU,8
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.37.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [console_scripts]
2
+ cpuinfo = cpuinfo:main
3
+
.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ cpuinfo
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/layers.cpython-311.pyc ADDED
Binary file (1.03 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/models.cpython-311.pyc ADDED
Binary file (6.17 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/request.cpython-311.pyc ADDED
Binary file (1.72 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.76 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/worker_manager.cpython-311.pyc ADDED
Binary file (2.4 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/adapter_commons/layers.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Tuple
5
+
6
+
7
+ @dataclass
8
+ class AdapterMapping:
9
+ # Per every token in input_ids:
10
+ index_mapping: Tuple[int, ...]
11
+ # Per sampled token:
12
+ prompt_mapping: Tuple[int, ...]
13
+
14
+ def __post_init__(self):
15
+ self.index_mapping = tuple(self.index_mapping)
16
+ self.prompt_mapping = tuple(self.prompt_mapping)
.venv/lib/python3.11/site-packages/vllm/adapter_commons/models.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Callable, Dict, Optional, TypeVar
5
+
6
+ from torch import nn
7
+
8
+ from vllm.logger import init_logger
9
+ from vllm.utils import LRUCache
10
+
11
+ logger = init_logger(__name__)
12
+
13
+
14
+ class AdapterModel(ABC):
15
+
16
+ def __init__(self, model_id=None):
17
+ self.id = model_id
18
+
19
+ @abstractmethod
20
+ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
21
+ # Common initialization code
22
+ # Load weights or embeddings from local checkpoint
23
+ raise NotImplementedError("Subclasses must implement this method.")
24
+
25
+
26
+ T = TypeVar('T')
27
+
28
+
29
+ class AdapterLRUCache(LRUCache[int, T]):
30
+
31
+ def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
32
+ super().__init__(capacity)
33
+ self.deactivate_fn = deactivate_fn
34
+
35
+ def _on_remove(self, key: int, value: Optional[T]):
36
+ logger.debug("Removing adapter int id: %d", key)
37
+ self.deactivate_fn(key)
38
+ return super()._on_remove(key, value)
39
+
40
+
41
+ class AdapterModelManager(ABC):
42
+
43
+ def __init__(
44
+ self,
45
+ model: nn.Module,
46
+ ):
47
+ """Create a AdapterModelManager and adapter for a given model.
48
+ Args:
49
+ model: the model to be adapted.
50
+ """
51
+ self.model: nn.Module = model
52
+ self._registered_adapters: Dict[int, Any] = {}
53
+ # Dict instead of a Set for compatibility with LRUCache.
54
+ self._active_adapters: Dict[int, None] = {}
55
+ self.adapter_type = 'Adapter'
56
+ self._last_mapping = None
57
+
58
+ def __len__(self) -> int:
59
+ return len(self._registered_adapters)
60
+
61
+ @property
62
+ @abstractmethod
63
+ def adapter_slots(self) -> int:
64
+ raise NotImplementedError
65
+
66
+ @property
67
+ @abstractmethod
68
+ def capacity(self) -> int:
69
+ raise NotImplementedError
70
+
71
+ @abstractmethod
72
+ def activate_adapter(self, adapter_id: int) -> bool:
73
+ raise NotImplementedError
74
+
75
+ @abstractmethod
76
+ def deactivate_adapter(self, adapter_id: int) -> bool:
77
+ raise NotImplementedError
78
+
79
+ @abstractmethod
80
+ def add_adapter(self, adapter: Any) -> bool:
81
+ raise NotImplementedError
82
+
83
+ @abstractmethod
84
+ def set_adapter_mapping(self, mapping: Any) -> None:
85
+ raise NotImplementedError
86
+
87
+ @abstractmethod
88
+ def remove_adapter(self, adapter_id: int) -> bool:
89
+ raise NotImplementedError
90
+
91
+ @abstractmethod
92
+ def remove_all_adapters(self) -> None:
93
+ raise NotImplementedError
94
+
95
+ @abstractmethod
96
+ def get_adapter(self, adapter_id: int) -> Optional[Any]:
97
+ raise NotImplementedError
98
+
99
+ @abstractmethod
100
+ def list_adapters(self) -> Dict[int, Any]:
101
+ raise NotImplementedError
102
+
103
+ @abstractmethod
104
+ def pin_adapter(self, adapter_id: int) -> bool:
105
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/vllm/adapter_commons/request.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+
6
+ class AdapterRequest(ABC):
7
+ """
8
+ Base class for adapter requests.
9
+ """
10
+
11
+ @property
12
+ @abstractmethod
13
+ def adapter_id(self) -> int:
14
+ raise NotImplementedError
15
+
16
+ def __post_init__(self) -> None:
17
+ if self.adapter_id < 1:
18
+ raise ValueError(f"id must be > 0, got {self.adapter_id}")
19
+
20
+ def __eq__(self, value: object) -> bool:
21
+ return isinstance(
22
+ value, self.__class__) and self.adapter_id == value.adapter_id
23
+
24
+ def __hash__(self) -> int:
25
+ return hash(self.adapter_id)