diff --git a/.venv/lib/python3.11/site-packages/depyf/VERSION.txt b/.venv/lib/python3.11/site-packages/depyf/VERSION.txt new file mode 100644 index 0000000000000000000000000000000000000000..47d04a528837ea50434734bd7cca947d47c4e012 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/VERSION.txt @@ -0,0 +1 @@ +0.18.0 \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/depyf/__init__.py b/.venv/lib/python3.11/site-packages/depyf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e06eda61bfccfdcd99446384589d52c848fb0702 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/__init__.py @@ -0,0 +1,24 @@ +from types import CodeType +import warnings + +from .decompiler import Decompiler, decompile + +try: + import torch + torch_version = torch.__version__ + valid = ("dev" not in torch_version and torch_version >= "2.2") or ( + "dev" in torch_version and torch_version.split("dev")[-1] >= "20231020") + if not valid: + warnings.warn( + ("Please use the nightly version of PyTorch to enable bytecode hooks.\n" + "PyTorch nightly can be installed by: `conda install pytorch-nightly::pytorch torchvision torchaudio -c pytorch-nightly`")) + + from depyf.explain.enhance_logging import install, uninstall + from depyf.explain.enable_debugging import prepare_debug, debug +except ImportError as e: + # print(e) + pass + +import os + +__version__ = open(f"{os.path.dirname(__file__)}/VERSION.txt").read().strip() diff --git a/.venv/lib/python3.11/site-packages/depyf/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..172a75783366c758aacfc4c1764cb05fba1b779c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/__pycache__/code_transform.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/__pycache__/code_transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..993c8de3b1edcd4fd2403fbe7715a1b3b0a52629 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/__pycache__/code_transform.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/__pycache__/decompiler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/__pycache__/decompiler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfdeb8d258d6d126dc9cf6ce28280217ff8a7c9d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/__pycache__/decompiler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/__pycache__/optimization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/__pycache__/optimization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1334498cdafbd04c8d5e47f79f4324354c86484 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/__pycache__/optimization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7650a5b5f69e34fa35592f6d20551a79925ce223 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/code_transform.py b/.venv/lib/python3.11/site-packages/depyf/code_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..20985aec798686b54b0e7ab281e69a87da2273d8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/code_transform.py @@ -0,0 +1,475 @@ +import dis +from typing import List, Tuple, Union, Optional, Callable, Any, Dict, Set +from types import CodeType +import ast +import astor +from collections import defaultdict +import dataclasses +import sys +import hashlib + +py311 = sys.version_info >= (3, 11) +all_jump_opcode_set = set(dis.hasjabs) | set(dis.hasjrel) + + +@dataclasses.dataclass +class Instruction: + """A mutable version of dis.Instruction""" + + opcode: int + opname: str + arg: Optional[int] + argval: Any + argrepr: str + offset: Optional[int] = None + starts_line: Optional[int] = None + is_jump_target: bool = False + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return id(self) == id(other) + + def short_inst_repr(self): + return f"Instruction(opname={self.opname}, offset={self.offset})" + + def is_jump(self): + return self.opcode in all_jump_opcode_set + + def get_jump_target(self: "Instruction"): + if self.is_jump() and "to " in self.argrepr: + return int(self.argrepr.replace("to ", "").strip()) + # seems like a bug, "FOR_ITER" is in `dis.hasjrel`, but its `argval` is + # an absolute offset + if self.opcode in dis.hasjabs: + return self.argval + elif self.opcode in dis.hasjrel: + return self.offset + self.argval if not py311 else self.argval + else: + raise ValueError( + f"Instruction {self.opname} does not have jump target") + + +def convert_instruction(i: dis.Instruction) -> Instruction: + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.argrepr, + i.offset, + i.starts_line, + i.is_jump_target, + ) + + +def nop_instruction(inst: Instruction): + """Inplace modify an instruction as nop.""" + inst.opname = "NOP" + inst.opcode = dis.opmap["NOP"] + inst.arg = 0 + inst.argval = 0 + inst.argrepr = "" + inst.offset + inst.starts_line + inst.is_jump_target = False + return inst + + +def propagate_line_nums(instructions: List[Instruction]): + """Ensure every instruction has line number set in case some are removed""" + cur_line_no = None + + def populate_line_num(inst): + nonlocal cur_line_no + if inst.starts_line: + cur_line_no = inst.starts_line + + inst.starts_line = cur_line_no + + for inst in instructions: + populate_line_num(inst) + + +# ======= begin code borrowed from pytorch/torch/_dynamo/bytecode_transformation.py =========== +@dataclasses.dataclass +class ExceptionTableEntry: + start: int + end: int + target: int + depth: int + lasti: bool + +def decode_exception_table_varint(bytes_iter) -> int: + """ + Inverse of `encode_exception_table_varint`. + """ + b = next(bytes_iter) + val = b & 63 + while b & 64: + val <<= 6 + b = next(bytes_iter) + val |= b & 63 + return val + +def check_exception_table(tab: List[ExceptionTableEntry]) -> None: + """ + Verifies that a list of ExceptionTableEntries will make a well-formed + jump table: entries are non-empty, sorted, and do not overlap. + """ + for i in range(len(tab) - 1): + assert ( + tab[i].start <= tab[i].end + and tab[i].end < tab[i + 1].start + and tab[i + 1].start <= tab[i + 1].end + ) + +def parse_exception_table(exntab) -> List[ExceptionTableEntry]: + """ + Parse the exception table according to + https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + exntab_iter = iter(exntab) + tab = [] + try: + while True: + start = decode_exception_table_varint(exntab_iter) * 2 + length = decode_exception_table_varint(exntab_iter) * 2 + end = start + length - 2 + target = decode_exception_table_varint(exntab_iter) * 2 + dl = decode_exception_table_varint(exntab_iter) + depth = dl >> 1 + lasti = bool(dl & 1) + tab.append(ExceptionTableEntry(start, end, target, depth, lasti)) + except StopIteration: + check_exception_table(tab) + return tab +# ======= end code borrowed from pytorch/torch/_dynamo/bytecode_transformation.py =========== + +def simplify_finally_statement(instructions: List[Instruction]): + """Simplify finally statement. + 3.10 finally statement: + SETUP_FINALLY + body + POP_BLOCK + finally code + Exception code + RERAISE + """ + for i, inst in enumerate(instructions): + if inst.opname == "SETUP_FINALLY": + finally_target = inst.get_jump_target() + reraise_idx = [j for j, _inst in enumerate( + instructions) if _inst.offset >= finally_target and _inst.opname == "RERAISE"] + if reraise_idx: + reraise_index = reraise_idx[0] + for j, _inst in enumerate(instructions): + if _inst.offset >= finally_target and j <= reraise_index: + nop_instruction(_inst) + + +def nop_unreachable_bytecode(code, + instructions: List[dis.Instruction]) -> List[dis.Instruction]: + """Mark unreachable bytecode as NOP.""" + jumps = set(dis.hasjabs) | set(dis.hasjrel) + + exception_targets = {} + if py311: + tab = parse_exception_table(code.co_exceptiontable) + exception_targets = {entry.target: entry for entry in tab} + + # difference bwteween `i in deadcode_positions` and `reachable[i] == False`: + # `i in deadcode_positions` means that the instruction is not reachable, defnitely a NOP + # `reachable[i] == False` means that the instruction is not reachable currently, but it might be reachable later when we iterate through the instructions + reachable = [False for x in instructions] + deadcode_positions = set() + reachable[0] = True + # each instruction marks the instruction after it + for i, inst in enumerate(instructions): + if inst.is_jump_target or inst.offset in exception_targets: + # the instruction is the target of a jump + reachable[i] = True + # the last instruction does not need to mark any following instructions + if i == len(instructions) - 1: + break + # this instruction is not reachable, nothing to do + if not reachable[i]: + continue + # this instruction is reachable + # the following instruction is reachable if it is sequential op or + # conditional jump + if inst.opname in ["RETURN_VALUE", "BREAK_LOOP"]: + # the instruction after the return is unreachable + pass + elif inst.opcode in jumps: + if inst.opcode in dis.hasjrel and inst.get_jump_target() == inst.offset: + # this is a jump to itself, it is regarded as a NOP, per the documentation at + # https://devguide.python.org/internals/interpreter/#jumps + reachable[i] = False + reachable[i + 1] = True + continue + if "IF" in inst.opname or "FOR_ITER" in inst.opname or "SETUP_LOOP" in inst.opname: + # the fallback block is always reachable for conditional jumps + reachable[i + 1] = True + elif inst.opname in ["SETUP_FINALLY", "SETUP_WITH", "BEFORE_WITH"]: + # the with/finally block is always reachable + reachable[i + 1] = True + else: + # this is a direct jump, the target is reachable + # we further check if any outside instructions jump into in-between instructions + # if not, we can mark this instruction as unreachable, too + # later, in-between instructions will be marked as unreachable (NOP) + # and the interpreter will slide through all the NOP directly + # to the target + jump_forwards = [j for j, instruct in enumerate( + instructions) if instruct.offset >= inst.get_jump_target()] + if len(jump_forwards): + j = jump_forwards[0] + if j > i: + smallest_jump_in = j + has_jump_in = False + + for ii, inst_ii in enumerate(instructions[i: j]): + # in python 3.11 exception table + # exception target indicates a jump target from many instructions + # and therefore it is treated as a jump-in + if inst_ii.offset in exception_targets: + has_jump_in = True + smallest_jump_in = min( + smallest_jump_in, ii) + + for ii, inst_ii in enumerate(instructions): + try: + jump_location = inst_ii.get_jump_target() + if (ii < i or ii > j) and (jump_location >= inst.offset and jump_location < instructions[j].offset): + has_jump_in = True + smallest_jump_in = min( + smallest_jump_in, ii) + except Exception: + pass + if not has_jump_in: + reachable[i] = False + for _ in range(i, smallest_jump_in): + deadcode_positions.add(_) + else: + reachable[i + 1] = True + + for i in deadcode_positions: + reachable[i] = False + + # mark unreachable instructions as NOP + for inst, flag in zip(instructions, reachable): + if not flag: + nop_instruction(inst) + + +def add_indentation(code: str, indentation: int = 4) -> str: + """Add indentation to code.""" + return "".join( + " " * + indentation + + line + + "\n" for line in code.splitlines()) + + +def remove_indentation(code: str, indentation: int = 4) -> str: + """Remove indentation from code.""" + return "".join(line[indentation:] + "\n" for line in code.splitlines()) + + +class RemoveAssignmentTransformer(ast.NodeTransformer): + def __init__(self, + temp_name: str, + temp_occurrences: Dict[str, + List[ast.Name]]): + # optimize one temp_name at a time + self.temp_name = temp_name + self.temp_occurrences = temp_occurrences + + def visit_Assign(self, node): + # single assimngment like `temp = xxx` + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + name = node.targets[0].id + # the assignment is like `temp = xxx` + if name == self.temp_name: + if len(self.temp_occurrences[name]) == 1: + return ast.Expr(value=node.value) + elif len(self.temp_occurrences[name]) == 3 and isinstance(self.temp_occurrences[name][-1], bool): + # we save the `xxx` here + self.temp_occurrences[name].append(node.value) + if self.temp_occurrences[name][-2]: + return None + return node + + +class RemoveAssignment2Transformer(ast.NodeTransformer): + def __init__(self, + temp_name: str, + temp_occurrences: Dict[str, + List[ast.Name]]): + # optimize one temp_name at a time + self.temp_name = temp_name + self.temp_occurrences = temp_occurrences + + def visit_Name(self, node): + name = node.id + if name == self.temp_name and len(self.temp_occurrences[name]) == 4 and isinstance( + self.temp_occurrences[name][-2], bool): + if self.temp_occurrences[name][-2]: + return self.temp_occurrences[name][-1] + return node + + +def get_parents(node): + """Collect all parent nodes of a given node.""" + parents = [] + while node: + parents.append(node) + node = getattr(node, "parent", None) + return parents + + +def set_parents(node, parent=None): + """Recursively set the parent attribute for each node.""" + for child in ast.iter_child_nodes(node): + child.parent = parent + set_parents(child, child) + + +def lowest_common_parent(node1, node2): + """Get the lowest common parent for two nodes.""" + parents1 = get_parents(node1) + parents2 = get_parents(node2) + + # Reverse the parents list to start comparing from the root. + parents1.reverse() + parents2.reverse() + + last_common = None + for p1, p2 in zip(parents1, parents2): + if p1 is p2: + last_common = p1 + else: + break + return last_common, p1, p2 + + +def remove_some_temp( + source_code: str, + temp_prefix: str, + indentation: int = 4) -> str: + tree = ast.parse(source_code) + set_parents(tree) + + temp_occurrences = defaultdict(list) + for node in ast.walk(tree): + if isinstance(node, ast.Name) and node.id.startswith(temp_prefix): + temp_occurrences[node.id].append(node) + + for key in temp_occurrences: + if len(temp_occurrences[key]) == 2: + node1 = temp_occurrences[key][0] + node2 = temp_occurrences[key][1] + parent, parent1, parent2 = lowest_common_parent(node1, node2) + assignment_node = node1 if isinstance( + node1.parent, ast.Assign) else node2 + assignment_parent = parent1 if isinstance( + node1.parent, ast.Assign) else parent2 + indentation_nodes = ( + ast.FunctionDef, + ast.AsyncFunctionDef, + ast.For, + ast.AsyncFor, + ast.While, + ast.If, + ast.Try, + ast.With, + ast.AsyncWith, + ast.ClassDef) + # we cannot remove the assignment if the assignment `temp=xxx` is + # in an indentation block while the usage of `temp` is not + can_merge = not isinstance(assignment_parent, indentation_nodes) + temp_occurrences[key].append(can_merge) + tree = RemoveAssignmentTransformer(key, temp_occurrences).visit(tree) + tree = RemoveAssignment2Transformer(key, temp_occurrences).visit(tree) + + reconstructed_code = astor.to_source(tree, indent_with=" " * indentation) + return reconstructed_code + + +class IdentifierReplacer(ast.NodeTransformer): + + # def visit_Name(self, node): + # return ast.copy_location(ast.Name(id='PLACEHOLDER', ctx=node.ctx), node) + + def visit_FunctionDef(self, node): + node.name = 'PLACEHOLDER' + return self.generic_visit(node) + + # def visit_AsyncFunctionDef(self, node): + # node.name = 'PLACEHOLDER' + # return self.generic_visit(node) + + # def visit_ClassDef(self, node): + # node.name = 'PLACEHOLDER' + # return self.generic_visit(node) + + # def visit_Attribute(self, node): + # node.attr = 'PLACEHOLDER' + # return self.generic_visit(node) + + +def fix_irregular_code( + old_bytecode: CodeType, + src_code: str, + add_local_variables: Optional[List[str]]=None, + add_cellvars: Optional[List[str]]=None, + ) -> str: + function_name = src_code.split("(")[0].split()[-1] + new_code = src_code + if add_local_variables is not None or add_cellvars is not None: + lines = src_code.splitlines() + header = lines[0] + body = lines[1:] + headers = [header] + if add_local_variables: + added_line = "; ".join(f"{x} = None" for x in add_local_variables) + added_line = " " + added_line + " # this line helps Python to generate bytecode with at least the same number of local variables as the original function\n" + headers.append(added_line) + if add_cellvars: + added_line = "return " + ", ".join(x for x in add_cellvars) + added_line = ( + " def __helper_for_cellvars():\n" + " # this function helps Python to generate bytecode with at least the same number of cellvars as the original function\n" + ) + " " + added_line + headers.append(added_line) + new_code = "".join([x + "\n" for x in headers + body]) + + freevars = old_bytecode.co_freevars + if freevars: + tmp_code = ( + "def __helper_outer_function():\n" + " # this is a helper function to help compilers generate bytecode to read capture variables from closures, rather than reading values from global scope. The value of these variables does not matter, and will be determined in runtime.\n" + ) + for freevar in freevars: + tmp_code += f" {freevar} = None\n" + tmp_code += add_indentation(new_code, 4) + new_code = tmp_code + + # make sure the new bytecode has at least the same number of local variables as the original bytecode + # this seems to fix the test failure in https://github.com/thuml/depyf/actions/runs/7004325219/job/19051829613 , and might be related with the discussion in https://github.com/pytorch/pytorch/pull/111883 + compiled_code = compile(new_code, "noname", "exec") + from .utils import collect_all_code_objects + code_objects = collect_all_code_objects(compiled_code) + target_code = [x for x in code_objects if x.co_name == function_name][0] + + missing_local_variables = set(old_bytecode.co_varnames) - set(target_code.co_varnames) + missing_cellvars = set(old_bytecode.co_cellvars) - set(target_code.co_cellvars) + + if missing_local_variables or missing_cellvars: + return fix_irregular_code( + old_bytecode, src_code, + add_local_variables=sorted(list(missing_local_variables)), + add_cellvars=sorted(list(missing_cellvars))) + return new_code diff --git a/.venv/lib/python3.11/site-packages/depyf/decompiler.py b/.venv/lib/python3.11/site-packages/depyf/decompiler.py new file mode 100644 index 0000000000000000000000000000000000000000..6caf20b0c7bab0e109a5fc17c5c0bde11c7b7ff4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/decompiler.py @@ -0,0 +1,1312 @@ +"""A simple program to transform bytecode into more readable source code.""" + +import sys +import os +import dis +from types import CodeType +from typing import List, Tuple, Dict, Union, Callable, Optional +import dataclasses +import inspect +import functools +from collections import defaultdict +import contextlib + +from .code_transform import ( + nop_unreachable_bytecode, + nop_instruction, + add_indentation, + remove_indentation, + remove_some_temp, + propagate_line_nums, + convert_instruction, + simplify_finally_statement, + Instruction, +) +from .utils import ( + get_function_signature, +) + + +class DecompilationError(Exception): + """Custom exception class for decompilation.""" + + def __init__(self, message=""): + self.message = message + super().__init__(self.message) + + def __str__(self): + return f'DecompilationError: {self.message}' + + +@dataclasses.dataclass +class DecompilerState: + """State of decompiler, keep track of the evaluation stack, as well as the decompiled source code.""" + source_code: str + stack: list + inside_loop: bool = False + loop_start_index: int = -1 # inclusive + loop_end_index: int = -1 # exclusive + + +@dataclasses.dataclass +class Decompiler: + """A decompiler for a code object.""" + code: CodeType + temp_count: int = 0 + temp_prefix: str = "__temp_" + state: DecompilerState = dataclasses.field( + default_factory=lambda: DecompilerState( + source_code="", stack=[])) + indentation: int = 4 + + @contextlib.contextmanager + def new_state(self, stack, inside_loop=False, loop_start_index=-1, loop_end_index=-1): + """Create a new state for decompiler.""" + state = DecompilerState(source_code="", stack=stack, inside_loop=inside_loop, loop_start_index=loop_start_index, loop_end_index=loop_end_index) + old_state = self.state + if old_state.inside_loop and not state.inside_loop: + # inherit the loop state from the old state + state.inside_loop = old_state.inside_loop + state.loop_start_index = old_state.loop_start_index + state.loop_end_index = old_state.loop_end_index + self.state = state + yield + self.state = old_state + + +# ==================== Unsupported Instructions ============================= + def unimplemented_instruction(self, inst: Instruction): + raise NotImplementedError(f"Unsupported instruction: {inst.opname}") + + GET_YIELD_FROM_ITER = unimplemented_instruction + + # we don't support try-except/try-finally + POP_EXCEPT = WITH_EXCEPT_START = JUMP_IF_NOT_EXC_MATCH = CHECK_EG_MATCH = PUSH_EXC_INFO = PREP_RERAISE_STAR = WITH_CLEANUP_FINISH = CALL_FINALLY = POP_FINALLY = WITH_CLEANUP_START = SETUP_EXCEPT = CHECK_EXC_MATCH = CLEANUP_THROW = unimplemented_instruction + + # we don't support async/await + GET_AWAITABLE = GET_AITER = GET_ANEXT = END_ASYNC_FOR = BEFORE_ASYNC_WITH = SETUP_ASYNC_WITH = SEND = ASYNC_GEN_WRAP = unimplemented_instruction + + CACHE = unimplemented_instruction + + # we don't know these instructions + PRINT_EXPR = COPY_DICT_WITHOUT_KEYS = unimplemented_instruction + + # we only support bytecode for functions + IMPORT_STAR = unimplemented_instruction + + YIELD_FROM = SETUP_ANNOTATIONS = LOAD_BUILD_CLASS = MATCH_MAPPING = MATCH_SEQUENCE = MATCH_KEYS = MATCH_CLASS = unimplemented_instruction + + # don't find any interesting use case for these instructions + CALL_INTRINSIC_2 = unimplemented_instruction + + +# ==================== NOP Instructions ============================= + + def generic_nop(self, inst: Instruction): + pass + + # "EXTENDED_ARG" is treated as NOP here, because it has been handled by `dis.get_instructions`. + # The extended args are already merged into the following instruction's + # `inst.argval`. + EXTENDED_ARG = generic_nop + + NOP = RESUME = SETUP_LOOP = POP_BLOCK = PRECALL = BEGIN_FINALLY = END_FINALLY = generic_nop + + MAKE_CELL = generic_nop + + RERAISE = generic_nop + + # our FOR_ITER is different from CPython's FOR_ITER (as it does not need + # to explicitly consider the case of exhausted iterator), so we don't need + # to do anything here + END_FOR = generic_nop + + +# ==================== Load Instructions ============================= + + def LOAD_CONST(self, inst: Instruction): + """Push a constant onto the stack. + `inst.argval` is the constant value, we have to use `repr` to get the source code + """ + can_repr = False + try: + can_repr = eval(repr(inst.argval)) == inst.argval + except BaseException: + pass + if can_repr: + self.state.stack.append(repr(inst.argval)) + else: + if isinstance(inst.argval, type): + # Don't know why a class type get here, support this corner + # case anyway. + module = inst.argval.__module__ + name = inst.argval.__name__ + self.state.source_code += "import importlib\n" + temp_name = self.get_temp_name() + self.state.source_code += f'{temp_name} = importlib.import_module("{module}").{name}\n' + self.state.stack.append(temp_name) + elif inst.argrepr.startswith("torch."): + # Don't know why torch.xxx get here, support this corner case + # anyway. This deals with something like `torch.float`. + self.state.source_code += "import torch\n" + temp_name = self.get_temp_name() + self.state.source_code += f'{temp_name} = {inst.argval}\n' + self.state.stack.append(temp_name) + elif isinstance(inst.argval, CodeType): + # used in MAKE_FUNCTION + self.state.stack.append(inst.argval) + else: + self.state.stack.append(f"'__co_consts[{inst.arg}]'") + + def generic_load(self, inst: Instruction): + """`inst.argval` is the variable name, in string""" + if "NULL + " in inst.argrepr: + # Python 3.11 support + self.state.stack.append(None) + if inst.argrepr.startswith("."): + # list/set/tuple comprehension. + self.state.stack.append(inst.argval.replace(".", "comp_arg_")) + else: + self.state.stack.append(inst.argval) + + LOAD_FAST = LOAD_FAST_AND_CLEAR = LOAD_FAST_CHECK = LOAD_GLOBAL = LOAD_DEREF = LOAD_NAME = LOAD_CLASSDEREF = LOAD_CLOSURE = generic_load + + def LOAD_LOCALS(self, inst: Instruction): + self.state.stack.append("locals()") + self.replace_mutable_tos_with_temp() + + def LOAD_FROM_DICT_OR_GLOBALS(self, inst: Instruction): + tos = self.state.stack.pop() + self.state.stack.append( + f"{tos}[{inst.argval}] if '{inst.argval}' in {tos} else {inst.argval}") + self.replace_mutable_tos_with_temp() + + LOAD_FROM_DICT_OR_DEREF = LOAD_FROM_DICT_OR_GLOBALS + + def MAKE_FUNCTION(self, inst: Instruction): + if sys.version_info < (3, 11): + qual_name = self.state.stack.pop() + try: + qual_name = eval(qual_name) + except Exception: + pass + # qual_name for inner function is something like `LongformerEncoder.forward..create_custom_forward` + # get the last part of the name, which is the function name + func_name = qual_name.split(".")[-1] + if "<" in func_name: + self.state.source_code += f'"original function name {func_name} is illegal, use a temp name."\n' + func_name = self.get_temp_name() + else: + # Python 3.11 support, see + # https://docs.python.org/3.11/library/dis.html#opcode-MAKE_FUNCTION + func_name = self.get_temp_name() + code = self.state.stack.pop() + if inst.argval & 0x08: + # has closure + self.state.stack.pop() + if inst.argval & 0x04: + # has annotations + self.state.stack.pop() + kw_defaults = self.state.stack.pop() if inst.argval & 0x02 else {} + defaults = self.state.stack.pop() if inst.argval & 0x01 else () + if len(kw_defaults) or len(defaults): + print( + "Function with default arguments is not supported, ignore the default arguments") + this_index = self.index_of(inst.offset) + immediately_used = False + if self.instructions[this_index + 1].opname == "STORE_FAST": + # the function is immediately stored in a variable, use that + # variable name + func_name = self.instructions[this_index + 1].argval + immediately_used = True + inner_func = Decompiler(code).decompile(overwite_fn_name=func_name) + self.state.source_code += inner_func + if not immediately_used: + self.state.stack.append(func_name) + else: + # skip one instruction + return this_index + 2 + + def COPY_FREE_VARS(self, inst: Instruction): + # this opcode is used to copy free variables from the outer scope to the closure + # it affects the frame, but not the stack or the source code + pass + + def LOAD_ATTR(self, inst: Instruction): + lhs = str(self.state.stack.pop()) + rhs = inst.argval + if rhs.isidentifier(): + self.state.stack.append(f"{lhs}.{rhs}") + else: + self.state.stack.append(f"getattr({lhs}, {repr(rhs)})") + + def LOAD_SUPER_ATTR(self, inst: Instruction): + # not tested + self_obj = self.state.stack.pop() + cls_obj = self.state.stack.pop() + super_obj = self.state.stack.pop() + self.state.stack.append( + f"{super_obj}({cls_obj}, {self_obj}).{inst.argval}") + self.replace_mutable_tos_with_temp() + + def LOAD_METHOD(self, inst: Instruction): + self.state.stack.append(f"{self.state.stack.pop()}.{inst.argval}") + + def LOAD_ASSERTION_ERROR(self, inst: Instruction): + self.state.stack.append("AssertionError") + + def PUSH_NULL(self, inst: Instruction): + # the `None` object is used to represent `NULL` in python bytecode + self.state.stack.append(None) + + def GET_ITER(self, inst: Instruction): + tos = self.state.stack.pop() + self.state.stack.append(f"iter({tos})") + +# ==================== Store Instructions ============================= + + def generic_store(self, inst: Instruction): + left = inst.argval + right = self.state.stack.pop() + if left != right: + # Inplace operations like `+=` will pop the variable name from the stack, and push the result back to the stack + # leading to a source code like `x = x`. We need to avoid this. + self.state.source_code += f"{left} = {right}\n" + + STORE_FAST = STORE_GLOBAL = STORE_DEREF = STORE_NAME = generic_store + + def STORE_SUBSCR(self, inst: Instruction): + index = self.state.stack.pop() + x = self.state.stack.pop() + value = self.state.stack.pop() + self.state.source_code += f"{x}[{index}] = {value}\n" + + def STORE_SLICE(self, inst: Instruction): + # not tested, code according to + # https://docs.python.org/3.12/library/dis.html#opcode-STORE_SLICE + end = self.state.stack.pop() + start = self.state.stack.pop() + container = self.state.stack.pop() + value = self.state.stack.pop() + self.state.source_code += f"{container}[{start}:{end}] = {value}\n" + + def STORE_ATTR(self, inst: Instruction): + x = self.state.stack.pop() + value = self.state.stack.pop() + self.state.source_code += f"{x}.{inst.argval} = {value}\n" + +# ==================== Del Instructions ============================= + + def DELETE_SUBSCR(self, inst: Instruction): + index = self.state.stack.pop() + x = self.state.stack.pop() + self.state.source_code += f"del {x}[{index}]\n" + + def generic_delete(self, inst: Instruction): + self.state.source_code += f"del {inst.argval}\n" + + DELETE_NAME = DELETE_GLOBAL = DELETE_DEREF = generic_delete + # `DELETE_FAST` just reduces the ref count by one + # it does not occur as code `del x` in the source code + DELETE_FAST = generic_nop + + def DELETE_ATTR(self, inst: Instruction): + x = self.state.stack.pop() + self.state.source_code += f"del {x}.{inst.argval}\n" + +# ==================== Import Instructions ============================= + def IMPORT_NAME(self, inst: Instruction): + # TODO: check multi-level import, e.g. `import a.b.c` + name = inst.argval.split(".")[0] + fromlist = self.state.stack.pop() + level = self.state.stack.pop() + self.state.source_code += f"{name} = __import__({repr(inst.argval)}, fromlist={fromlist}, level={level})\n" + self.state.stack.append(name) + + def IMPORT_FROM(self, inst: Instruction): + name = inst.argval + module = self.state.stack[-1] + self.state.source_code += f"{name} = {module}.{name}\n" + self.state.stack.append(name) + +# ==================== Unary Instructions ============================= + + def generic_unary(self, inst: Instruction): + op = { + "UNARY_NEGATIVE": "-", + "UNARY_POSITIVE": "+", + "UNARY_INVERT": "~", + "UNARY_NOT": "not", + }[inst.opname] + self.state.stack.append(f"({op} {self.state.stack.pop()})") + + UNARY_NEGATIVE = UNARY_POSITIVE = UNARY_INVERT = UNARY_NOT = generic_unary + + def GET_LEN(self, inst: Instruction): + self.state.stack.append(f"len({self.state.stack[-1]})") + +# ==================== Binary Instructions ============================= + def generic_binary(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + op = { + "BINARY_MULTIPLY": "*", + "BINARY_ADD": "+", + "BINARY_SUBTRACT": "-", + "BINARY_TRUE_DIVIDE": "/", + "BINARY_FLOOR_DIVIDE": "//", + "BINARY_MODULO": "%", + "BINARY_POWER": "**", + "BINARY_AND": "&", + "BINARY_OR": "|", + "BINARY_XOR": "^", + "BINARY_LSHIFT": "<<", + "BINARY_RSHIFT": ">>", + "BINARY_MATRIX_MULTIPLY": "@", + }[inst.opname] + self.state.stack.append(f"({lhs} {op} {rhs})") + + BINARY_MULTIPLY = BINARY_ADD = BINARY_SUBTRACT = BINARY_TRUE_DIVIDE = BINARY_FLOOR_DIVIDE = BINARY_MODULO = BINARY_POWER = BINARY_AND = BINARY_OR = BINARY_XOR = BINARY_LSHIFT = BINARY_RSHIFT = BINARY_MATRIX_MULTIPLY = generic_binary + + def BINARY_SUBSCR(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + self.state.stack.append(f"{lhs}[{rhs}]") + + def BINARY_SLICE(self, inst: Instruction): + end = self.state.stack.pop() + start = self.state.stack.pop() + container = self.state.stack.pop() + self.state.stack.append(f"{container}[{start}:{end}]") + +# ==================== Binary Inplace Instructions ======================= + def generic_inplace_binary(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + op = { + "INPLACE_MULTIPLY": "*", + "INPLACE_ADD": "+", + "INPLACE_SUBTRACT": "-", + "INPLACE_TRUE_DIVIDE": "/", + "INPLACE_FLOOR_DIVIDE": "//", + "INPLACE_MODULO": "%", + "INPLACE_POWER": "**", + "INPLACE_AND": "&", + "INPLACE_OR": "|", + "INPLACE_XOR": "^", + "INPLACE_LSHIFT": "<<", + "INPLACE_RSHIFT": ">>", + "INPLACE_MATRIX_MULTIPLY": "@", + }[inst.opname] + self.state.source_code += f"{lhs} {op}= {rhs}\n" + self.state.stack.append(lhs) + + INPLACE_MULTIPLY = INPLACE_ADD = INPLACE_SUBTRACT = INPLACE_TRUE_DIVIDE = INPLACE_FLOOR_DIVIDE = INPLACE_MODULO = INPLACE_POWER = INPLACE_AND = INPLACE_OR = INPLACE_XOR = INPLACE_LSHIFT = INPLACE_RSHIFT = INPLACE_MATRIX_MULTIPLY = generic_inplace_binary + + def BINARY_OP(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + if "=" in inst.argrepr: + self.state.source_code += f"{lhs} {inst.argrepr} {rhs}\n" + self.state.stack.append(lhs) + else: + self.state.stack.append(f"({lhs} {inst.argrepr} {rhs})") + +# ==================== Conditional Test Instructions ===================== + def COMPARE_OP(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + self.state.stack.append(f"({lhs} {inst.argval} {rhs})") + + def IS_OP(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + op = "is" if inst.argval == 0 else "is not" + self.state.stack.append(f"({lhs} {op} {rhs})") + + def CONTAINS_OP(self, inst: Instruction): + rhs = self.state.stack.pop() + lhs = self.state.stack.pop() + op = "in" if inst.argval == 0 else "not in" + self.state.stack.append(f"({lhs} {op} {rhs})") + +# ==================== Control Flow Instructions ============================= + + def BREAK_LOOP(self, inst: Instruction): + self.state.source_code += "break\n" + + def generic_abs_jump(self, inst: Instruction): + jump_offset = inst.get_jump_target() + jump_index = self.index_of(jump_offset) + if self.state.inside_loop: + if jump_index >= self.state.loop_end_index: + self.state.source_code += "break\n" + elif jump_index <= self.state.loop_start_index: + self.state.source_code += "continue\n" + else: + return jump_index + else: + return jump_index + + JUMP_ABSOLUTE = JUMP_FORWARD = JUMP_BACKWARD = JUMP_BACKWARD_NO_INTERRUPT = generic_abs_jump + + def RETURN_VALUE(self, inst: Instruction): + self.state.source_code += f"return {self.state.stack[-1]}\n" + self.state.stack.pop() + + def RETURN_CONST(self, inst: Instruction): + self.state.source_code += f"return {inst.argval}\n" + + def YIELD_VALUE(self, inst: Instruction): + if sys.version_info >= (3, 12): + raise NotImplementedError( + "YIELD_VALUE is not supported in Python 3.12") + self.state.source_code += f"yield {self.state.stack[-1]}\n" + + def RETURN_GENERATOR(self, inst: Instruction): + # we don't handle generator/coroutine, add this to support simple yield + self.state.stack.append(None) + + def GEN_START(self, inst: Instruction): + # self.state.stack.pop() + assert inst.argval == 0, "Only generator expression is supported" + + def generic_jump_if(self, inst: Instruction): + """How we support if-else: + + Failed idea: try to paritition the block of instructions into if and else. + This is not possible, as the if-else block might have overlapping instructions. + Take this function as an example: + + def f(a): + b = 1 if a else 2 + print(b) + + The bytecode is: + 2 0 LOAD_FAST 0 (a) + 2 POP_JUMP_IF_FALSE 4 (to 8) + 4 LOAD_CONST 1 (1) + 6 JUMP_FORWARD 1 (to 10) + >> 8 LOAD_CONST 2 (2) + >> 10 STORE_FAST 1 (b) + + 3 12 LOAD_GLOBAL 0 (print) + 14 LOAD_FAST 1 (b) + 16 CALL_FUNCTION 1 + 18 POP_TOP + 20 LOAD_CONST 0 (None) + 22 RETURN_VALUE + + The instructions for if branch: 2, 4, 6, 10 + The instructions for else branch: 8, 10 + They share the same instruction 10, so we cannot partition the block into if and else. + + Another example: + + def f(): + g(arg1=a if a is not None else b, arg2=2) + print(1) + + The bytecode is: + + 2 0 LOAD_GLOBAL 0 (g) + 2 LOAD_GLOBAL 1 (a) + 4 LOAD_CONST 0 (None) + 6 IS_OP 1 + 8 POP_JUMP_IF_FALSE 7 (to 14) + 10 LOAD_GLOBAL 1 (a) + 12 JUMP_FORWARD 1 (to 16) + >> 14 LOAD_GLOBAL 2 (b) + >> 16 LOAD_CONST 1 (2) + 18 LOAD_CONST 2 (('arg1', 'arg2')) + 20 CALL_FUNCTION_KW 2 + 22 POP_TOP + + 3 24 LOAD_GLOBAL 3 (print) + 26 LOAD_CONST 3 (1) + 28 CALL_FUNCTION 1 + 30 POP_TOP + 32 LOAD_CONST 0 (None) + 34 RETURN_VALUE + + The instructions for if branch: 8, 14, 16, 18, 20, 22 + The instructions for else branch: 10, 12, 16, 18, 20, 22 + They share the same instructions 16, 18, 20, 22, so we cannot partition the block into if and else. + + Current idea: + + We take advantage of the following fact: + + This code snippet: + + if cond: + if-body + else: + else-body + rest-body + + is equivalent to: + + if cond: + if-body + rest-body + else: + else-body + rest-body + + By duplicating the rest-body, we can decompile the if-else block separately. And they will have some duplicated code. + + Of course, we don't want to duplicate too long code, so we need to find the end of if-else block. + The current heuristic is to find the first store/return/jump/for-iter instruction after the if-else block (because they are indicators that we will generate meaningful source code). + """ + jump_offset = inst.get_jump_target() + jump_index = self.index_of(jump_offset) + this_index = self.index_of(inst.offset) + cond = self.state.stack[-1] + fallthrough_stack = self.state.stack.copy() + jump_stack = self.state.stack.copy() + + if "IF_NOT_NONE" in inst.opname: + cond = f"({cond} is None)" + elif "IF_NONE" in inst.opname: + cond = f"({cond} is not None)" + elif "IF_TRUE" in inst.opname: + cond = f"(not {cond})" + elif "IF_FALSE" in inst.opname: + cond = f"{cond}" + + # POP_AND_JUMP / JUMP_OR_POP + if "POP_JUMP" in inst.opname: + jump_stack.pop() + fallthrough_stack.pop() + elif "OR_POP" in inst.opname: + fallthrough_stack.pop() + + end_index_candidates = [len(self.instructions)] + if self.state.inside_loop: + end_index_candidates.append(self.state.loop_end_index) + + def qualified_jump(i: Instruction): + return i.is_jump() and i.get_jump_target() >= jump_offset + + jump_targets = [i.get_jump_target() for i in self.instructions[this_index: jump_index] if qualified_jump(i)] + + if not jump_targets: + # this is a jump back, we will generate a ``continue`` statement + # normally `if` condition is for the fallthrough code, but in this case + # we need to generate the `if` condition for the jump code + # therefore the condition is reversed + cond = self.state.stack[-1] + if "IF_NOT_NONE" in inst.opname: + cond = f"({cond} is not None)" + elif "IF_NONE" in inst.opname: + cond = f"({cond} is None)" + elif "IF_TRUE" in inst.opname: + cond = f"{cond}" + elif "IF_FALSE" in inst.opname: + cond = f"(not {cond})" + if_code = f"if {cond}:\n" + add_indentation("continue\n", self.indentation) + self.state.source_code += if_code + return + + max_jump = max(jump_targets) + max_jump_index = self.index_of(max_jump) + # else branch might have jumps, we need to find the end of the else + all_jump_targets = [i.get_jump_target() for i in self.instructions[this_index: max_jump_index] if qualified_jump(i)] + max_jump_index = self.index_of(max(all_jump_targets)) + last_inst = self.instructions[max_jump_index - 1] + if "RAISE" in last_inst.opname or "RETURN" in last_inst.opname or "STORE" in last_inst.opname: + # if-body instructions end with raise/return/store, it is very likely that if-body and else-body don't share any instructions + pass + else: + old_map_jump_index = max_jump_index + while max_jump_index < len(self.instructions): + opname = self.instructions[max_jump_index].opname + if "STORE" in opname or "RETURN" in opname: + # we want to include the store/return instruction in the if-else block + max_jump_index += 1 + break + elif ("JUMP" in opname and max_jump_index > old_map_jump_index) or "FOR_ITER" in opname: + # we don't want to include the jump instruction in the if-else block + break + max_jump_index += 1 + end_index_candidates.append(max_jump_index) + + end_index = min(end_index_candidates) + + with self.new_state(fallthrough_stack): + self.decompile_range(this_index + 1, end_index) + if_body = self.state.source_code + if_body = add_indentation(if_body, self.indentation) + if_end_stack = self.state.stack.copy() + if_code = f"if {cond}:\n{if_body}" + self.state.source_code += if_code + + with self.new_state(jump_stack): + self.decompile_range(jump_index, end_index) + else_body = self.state.source_code + if else_body: + else_body = add_indentation(else_body, self.indentation) + else_code = f"else:\n{else_body}" + self.state.source_code += else_code + + self.state.stack = if_end_stack + return end_index + + + POP_JUMP_IF_TRUE = POP_JUMP_IF_FALSE = generic_jump_if + POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_FORWARD_IF_FALSE = generic_jump_if + POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_BACKWARD_IF_FALSE = generic_jump_if + POP_JUMP_FORWARD_IF_NONE = POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump_if + POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump_if + JUMP_IF_TRUE_OR_POP = JUMP_IF_FALSE_OR_POP = generic_jump_if + POP_JUMP_IF_NOT_NONE = POP_JUMP_BACKWARD_IF_NOT_NONE + POP_JUMP_IF_NONE = POP_JUMP_BACKWARD_IF_NONE + + def SETUP_FINALLY(self, inst: Instruction): + start_index = self.index_of(inst.offset) + end_index = self.index_of(inst.get_jump_target()) + pop_block_index = [i for i, x in enumerate( + self.instructions) if x.opname == "POP_BLOCK" and start_index <= i < end_index][-1] + + try_code = "" + with self.new_state(self.state.stack): + self.decompile_range(start_index + 1, pop_block_index) + try_code = self.state.source_code + try_code = add_indentation(try_code, self.indentation) + try_code = "try:\n" + try_code + + finally_code = "" + with self.new_state(self.state.stack): + end_finally_index = [ + i for i, x in enumerate( + self.instructions) if x.opname == "END_FINALLY" and start_index <= i] + if end_finally_index: + end_index = end_finally_index[0] + finally_end_index = end_index + if self.instructions[finally_end_index - 1].is_jump(): + finally_end_index -= 1 + self.decompile_range(pop_block_index + 1, finally_end_index) + finally_code = self.state.source_code + finally_code = add_indentation(finally_code, self.indentation) + finally_code = "finally:\n" + finally_code + + self.state.source_code += try_code + finally_code + return end_index + + def SETUP_WITH(self, inst: Instruction): + """ + with expression as var: + body + + is equivalent to: + + var = expression + var.__enter__() + try: + body + finally: + var.__exit__() + + We find the start of `finally` by `WITH_EXCEPT_START`, and the end of `finally` by `POP_EXCEPT`. + In early python version, the start is `WITH_CLEANUP_START` and the end is `WITH_CLEANUP_FINISH`. + """ + start_index = self.index_of(inst.offset) + with_except_index = [i for i, x in enumerate( + self.instructions) if x.opname in ["WITH_EXCEPT_START", "WITH_CLEANUP_START"] and i > start_index][-1] + end_index = with_except_index + nop_instruction(self.instructions[end_index]) + + # NOP PUSH_EXC_INFO and JUMP_FORWARD + i = end_index - 1 + while end_index - i <= 2: + _inst = self.instructions[i] + if _inst.opname.startswith("JUMP") or _inst.opname == "PUSH_EXC_INFO": + nop_instruction(_inst) + i -= 1 + + pop_except_indices = [i for i, x in enumerate( + self.instructions) if x.opname in ["POP_EXCEPT", "WITH_CLEANUP_FINISH"] and i > end_index] + if sys.version_info >= (3, 11): + # Python 3.11 seems to have two `POP_EXCEPT` instructions, not sure why. + pop_except_index = pop_except_indices[1] + else: + pop_except_index = pop_except_indices[0] + for i in range(end_index, pop_except_index + 1): + nop_instruction(self.instructions[i]) + tos = self.state.stack[-1] + temp = self.get_temp_name() + self.state.stack.append(f"{temp}.__exit__") + self.state.stack.append(temp) + with_clause = f"with {tos} as {temp}:\n" + with_body = "" + with self.new_state(self.state.stack): + self.decompile_range(start_index + 1, end_index) + with_body = self.state.source_code + with_body = add_indentation(with_body, self.indentation) + lines = with_body.splitlines() + ans = [] + for line in lines: + if f"{temp}.__exit__" in line or "None(None, None)" in line.strip(): + # this is the line that calls __exit__, we need to remove it, as it is managed by `with` statement. + # `None(None, None)` is used for Python 3.11. Who knows why it loads three Nones but call with 2 args for the following simple code: + # def f(): + # with a: + # print(2) + continue + ans.append(line) + with_body = "".join([x + "\n" for x in ans]) + + self.state.source_code += with_clause + with_body + return pop_except_index + 1 + + BEFORE_WITH = SETUP_WITH + + def FOR_ITER(self, inst: Instruction): + start_index = self.index_of(inst.offset) + end_index = self.index_of(inst.get_jump_target()) + + temp_name = self.get_temp_name() + for_code = f"for {temp_name} in {self.state.stack.pop()}:\n" + self.state.stack.append(temp_name) + last_inst = self.instructions[end_index] + if last_inst.is_jump() and last_inst.get_jump_target() == inst.offset: + # if end_index is something like jumping back to for_iter, + # we should deal with it inside the loop + end_index += 1 + with self.new_state(self.state.stack, inside_loop=True, loop_start_index=start_index, loop_end_index=end_index): + self.decompile_range(start_index + 1, end_index) + code = self.state.source_code + for_code = for_code + add_indentation(code, self.indentation) + for_end_stack = self.state.stack.copy() + self.state.source_code += for_code + self.state.stack = for_end_stack + return end_index + +# ==================== Stack Manipulation Instructions =================== + def rot_n(self, inst: Instruction): + if inst.opname == "ROT_N": + n = inst.argval + else: + n = { + "ROT_TWO": 2, + "ROT_THREE": 3, + "ROT_FOUR": 4, + }[inst.opname] + values = self.state.stack[-n:] + values = [values[-1]] + values[:-1] + self.state.stack[-n:] = values + + ROT_N = ROT_TWO = ROT_THREE = ROT_FOUR = rot_n + + def SWAP(self, inst: Instruction): + n = inst.argval + tos = self.state.stack[-1] + value = self.state.stack[- n] + tos, value = value, tos + self.state.stack[-1] = tos + self.state.stack[- n] = value + + def COPY(self, inst: Instruction): + # not tested, don't know how to generate this instruction + n = inst.argval + value = self.state.stack[-1 - n] + self.state.stack.append(value) + + def POP_TOP(self, inst: Instruction): + self.state.stack.pop() + + def DUP_TOP(self, inst: Instruction): + # not tested + self.state.stack.append(self.state.stack[-1]) + + def DUP_TOP_TWO(self, inst: Instruction): + # not tested + tos = self.state.stack[-1] + tos1 = self.state.stack[-2] + self.state.stack.append(tos1) + self.state.stack.append(tos) + +# ==================== Function Call Instructions ============================= + def KW_NAMES(self, inst: Instruction): + names = self.code.co_consts[inst.arg] + self.state.stack.append(repr(names)) + + def CALL(self, inst: Instruction): + last_inst = [x for x in self.instructions if x.offset < inst.offset] + has_kw_names = False + if last_inst: + if last_inst[-1].opname == "KW_NAMES" or (len( + last_inst) > 1 and last_inst[-2].opname == "KW_NAMES" and last_inst[-1].opname == "PRECALL"): + has_kw_names = True + kw_names = tuple() + if has_kw_names: + kw_names = eval(self.state.stack.pop()) + args = [(self.state.stack.pop()) for _ in range(inst.argval)] + args = args[::-1] + pos_args = args[:len(args) - len(kw_names)] + kwargs = args[len(args) - len(kw_names):] + kwcalls = [] + for name, value in zip(kw_names, kwargs): + kwcalls.append(f"{name}={value}") + func = self.state.stack.pop() + if self.state.stack and self.state.stack[-1] is None: + self.state.stack.pop() + if "iter(" in func: + # Why do we need this? Don't know. But sometimes CPython generates + # CALL with argval=0, but the function actually needs an arg (for + # list/set/map comprehension). + pos_args = [func] + func = self.state.stack.pop() + self.state.stack.append(f"{func}({', '.join(pos_args + kwcalls)})") + self.replace_mutable_tos_with_temp() + + def generic_call(self, inst: Instruction): + args = [(self.state.stack.pop()) for _ in range(inst.argval)] + args = args[::-1] + func = self.state.stack.pop() + self.state.stack.append(f"{func}({', '.join(args)})") + self.replace_mutable_tos_with_temp() + + CALL_FUNCTION = CALL_METHOD = generic_call + + def CALL_FUNCTION_KW(self, inst: Instruction): + kw_args = eval(self.state.stack.pop()) + kw_vals = [(self.state.stack.pop()) for _ in range(len(kw_args))] + kw_vals.reverse() + kwcalls = [] + for name, val in zip(kw_args, kw_vals): + kwcalls.append(f"{name}={val}") + pos_args = [(self.state.stack.pop()) + for _ in range(inst.argval - len(kw_args))] + pos_args = pos_args[::-1] + func = self.state.stack.pop() + self.state.stack.append(f"{func}({', '.join(pos_args + kwcalls)})") + self.replace_mutable_tos_with_temp() + + def CALL_FUNCTION_EX(self, inst: Instruction): + if inst.argval == 0: + args = self.state.stack.pop() + func = self.state.stack.pop() + self.state.stack.append(f"{func}(*{args})") + elif inst.argval == 1: + kw_args = self.state.stack.pop() + args = self.state.stack.pop() + func = self.state.stack.pop() + self.state.stack.append(f"{func}(*{args}, **{kw_args})") + self.replace_mutable_tos_with_temp() + + def CALL_INTRINSIC_1(self, inst: Instruction): + if inst.argrepr in [ + "INTRINSIC_1_INVALID", + "INTRINSIC_IMPORT_STAR", + "INTRINSIC_STOPITERATION_ERROR", + "INTRINSIC_ASYNC_GEN_WRAP"]: + # invalid intrinsic, skip + pass + elif inst.argrepr in ["INTRINSIC_TYPEVAR", "INTRINSIC_PARAMSPEC", "INTRINSIC_TYPEVARTUPLE", "INTRINSIC_SUBSCRIPT_GENERIC", "INTRINSIC_TYPEALIAS"]: + # not tested, skip + pass + elif inst.argrepr == "INTRINSIC_PRINT": + self.state.source_code += f"print({self.state.stack.pop()})\n" + self.state.stack.append("None") + elif inst.argrepr == "INTRINSIC_UNARY_POSITIVE": + self.state.stack[-1] = f"+{self.state.stack[-1]}" + elif inst.argrepr == "INTRINSIC_LIST_TO_TUPLE": + return self.LIST_TO_TUPLE(inst) + + +# ==================== Container Related Instructions (tuple, list, set, d + + def UNPACK_SEQUENCE(self, inst: Instruction): + # sequence can be tuple, list, or even generator + # we cannot directly use indexing to get the elements + # because the sequence might be a generator (not subscriptable) + # instead, we use a temporary variable to store the unpacked elements + + # e.g. `a, b = (None for _ in (1, 2))` + # will be transformed into: + # __temp_1 = (None for _ in (1, 2)) + # __temp_2, __temp_3 = __temp_1 + # a = __temp_2 + # b = __temp_3 + varname = self.state.stack.pop() + tmp_names = [] + for i in range(inst.argval): + tmp_names.append(self.get_temp_name()) + # NOTE: even if there is only one element, we still need to unpack it + # a = b is different from a, = b + lhs = "".join([f"{x}, " for x in tmp_names]) + self.state.source_code += lhs + f"= {varname}\n" + for name in tmp_names[::-1]: + self.state.stack.append(name) + + def UNPACK_EX(self, inst: Instruction): + varname = self.state.stack.pop() + tmp_names = [] + for i in range(inst.argval): + tmp_names.append(self.get_temp_name()) + star_name = self.get_temp_name() + self.state.source_code += ", ".join(tmp_names) + f", *{star_name}" + f" = {varname}\n" + self.state.stack.append(star_name) + for name in tmp_names[::-1]: + self.state.stack.append(name) + + def BUILD_SLICE(self, inst: Instruction): + tos = self.state.stack.pop() + tos1 = self.state.stack.pop() + if inst.argval == 2: + self.state.stack.append(f"slice({tos1}, {tos})") + elif inst.argval == 3: + tos2 = self.state.stack.pop() + self.state.stack.append(f"slice({tos2}, {tos1}, {tos})") + + def build_tuple(self, inst: Instruction): + args = [self.state.stack.pop() for _ in range(inst.argval)] + args = args[::-1] + if "UNPACK" in inst.opname: + args = [f"*{arg}" for arg in args] + if inst.argval == 1: + self.state.stack.append(f"({args[0]},)") + else: + self.state.stack.append(f"({', '.join(args)})") + + BUILD_TUPLE = BUILD_TUPLE_UNPACK = BUILD_TUPLE_UNPACK_WITH_CALL = build_tuple + + def build_list(self, inst: Instruction): + args = [self.state.stack.pop() for _ in range(inst.argval)] + args = args[::-1] + if "UNPACK" in inst.opname: + args = [f"*{arg}" for arg in args] + self.state.stack.append(f"[{', '.join(args)}]") + self.replace_mutable_tos_with_temp() + + BUILD_LIST = BUILD_LIST_UNPACK = build_list + + def build_set(self, inst: Instruction): + ans = "" + if inst.argval == 0: + ans = "set()" + else: + args = [self.state.stack.pop() for _ in range(inst.argval)] + args = args[::-1] + if "UNPACK" in inst.opname: + args = [f"*{arg}" for arg in args] + ans = f"{{{', '.join(args)}}}" + self.state.stack.append(ans) + self.replace_mutable_tos_with_temp() + + BUILD_SET = BUILD_SET_UNPACK = build_set + + def build_map_unpack(self, inst: Instruction): + if inst.argval == 0: + self.state.stack.append("dict()") + else: + args = [self.state.stack.pop() for _ in range(inst.argval)] + args = args[::-1] + args = [f"**{arg}" for arg in args] + self.state.stack.append(f"{{{', '.join(args)}}}") + self.replace_mutable_tos_with_temp() + + BUILD_MAP_UNPACK = BUILD_MAP_UNPACK_WITH_CALL = build_map_unpack + + def BUILD_MAP(self, inst: Instruction): + args = [self.state.stack.pop() for _ in range(inst.argval * 2)] + args = args[::-1] + keys = args[::2] + values = args[1::2] + self.state.stack.append( + f"{{{', '.join([f'{k}: {v}' for k, v in zip(keys, values)])}}}") + self.replace_mutable_tos_with_temp() + + def BUILD_CONST_KEY_MAP(self, inst: Instruction): + keys = eval(self.state.stack.pop()) + args = [self.state.stack.pop() for _ in range(inst.argval)] + values = args[::-1] + self.state.stack.append( + f"{{{', '.join([f'{k}: {v}' for k, v in zip(keys, values)])}}}") + self.replace_mutable_tos_with_temp() + + def BUILD_STRING(self, inst: Instruction): + args = [self.state.stack.pop() for _ in range(inst.argval)] + args = args[::-1] + values = " + ".join(args) + self.state.stack.append(values) + + def LIST_TO_TUPLE(self, inst: Instruction): + item = self.state.stack.pop() + self.state.stack.append(f"tuple({item})") + + def LIST_EXTEND(self, inst: Instruction): + assert inst.argval == 1, "Only tested for argval==1" + values = self.state.stack.pop() + temp = self.replace_mutable_tos_with_temp() + self.state.source_code += f"{temp}.extend({values})\n" + + def LIST_APPEND(self, inst: Instruction): + if inst.argval == 1: + # it should be a bug, the tos should be the value. fix it anyway. + inst.argval += 1 + container = self.state.stack[-inst.argval] + value = self.state.stack.pop() + self.state.source_code += f"{container}.append({value})\n" + + def generic_update(self, inst: Instruction): + assert inst.argval == 1, "Only tested for argval==1" + values = self.state.stack.pop() + temp = self.replace_mutable_tos_with_temp() + self.state.source_code += f"{temp}.update({values})\n" + + SET_UPDATE = DICT_UPDATE = DICT_MERGE = generic_update + + def SET_ADD(self, inst: Instruction): + if inst.argval == 1: + # it should be a bug, the tos should be the value. fix it anyway. + inst.argval += 1 + container = self.state.stack[-inst.argval] + value = self.state.stack.pop() + self.state.source_code += f"{container}.add({value})\n" + + def MAP_ADD(self, inst: Instruction): + container = self.state.stack[-inst.argval - 1] + # see https://docs.python.org/3.10/library/dis.html#opcode-MAP_ADD + if sys.version_info >= (3, 8): + value = self.state.stack.pop() + key = self.state.stack.pop() + else: + key = self.state.stack.pop() + value = self.state.stack.pop() + self.state.source_code += f"{container}.__setitem__({key}, {value})\n" + +# ==================== Misc Instructions ============================= + def RAISE_VARARGS(self, inst: Instruction): + if inst.argval == 0: + self.state.source_code += "raise\n" + elif inst.argval == 1: + self.state.source_code += f"raise {self.state.stack.pop()}\n" + elif inst.argval == 2: + tos = self.state.stack.pop() + tos1 = self.state.stack.pop() + self.state.source_code += f"raise {tos1} from {tos}\n" + + def FORMAT_VALUE(self, inst: Instruction): + func, spec = inst.argval + if spec: + form_spec = self.state.stack.pop() + value = self.state.stack.pop() + self.state.stack.append(f"format({value}, {form_spec})") + else: + value = self.state.stack.pop() + func = str if func is None else func + self.state.stack.append(f"{func.__name__}({value})") + + + def decompile_range(self, start: int, end: int): + try: + running_index = start + while running_index < end: + inst = self.instructions[running_index] + method = getattr( + Decompiler, + inst.opname, + Decompiler.unimplemented_instruction) + output = method(self, inst) + if output: + running_index = output + else: + running_index += 1 + except Exception as e: + raise DecompilationError( + f"Failed to decompile instruction {inst} in {self.code.co_name}") from e + + def index_of(self, offset: int): + for idx, inst in enumerate(self.instructions): + if inst.offset == offset: + return idx + raise ValueError(f"Cannot find instruction with offset {offset}") + + @staticmethod + def cleanup_instructions(code, instructions: List[Instruction]): + propagate_line_nums(instructions) + simplify_finally_statement(instructions) + nop_unreachable_bytecode(code, instructions) + + def __init__(self, code: Union[CodeType, Callable]): + if callable(code): + from depyf.utils import get_code_owner + code = get_code_owner(code).__code__ + self.code = code + instructions = list(convert_instruction(_) + for _ in dis.get_instructions(code)) + Decompiler.cleanup_instructions(code, instructions) + self.instructions = instructions + self.state = DecompilerState(source_code="", stack=[]) + + def get_temp_name(self): + Decompiler.temp_count += 1 + return f"{self.temp_prefix}{Decompiler.temp_count}" + + def replace_mutable_tos_with_temp(self): + ans = self.state.stack.pop() + temp_name = self.get_temp_name() + self.state.source_code += f"{temp_name} = {ans}\n" + self.state.stack.append(temp_name) + return temp_name + + @staticmethod + def supported_opnames(): + opnames = [] + for x in dis.opname: + if getattr( + Decompiler, + x, + Decompiler.unimplemented_instruction) is not Decompiler.unimplemented_instruction: + opnames.append(x) + return opnames + + @functools.lru_cache(maxsize=None) + def decompile( + self, + indentation=4, + temp_prefix: str = "__temp_", + overwite_fn_name: Optional[str] = None) -> str: + try: + self.indentation = indentation + self.temp_prefix = temp_prefix + self.decompile_range(0, len(self.instructions)) + source_code = self.state.source_code + # the header might have invalid function name in torchdynamo. only + # optimize the function body. + source_code = remove_some_temp( + source_code, self.temp_prefix, indentation) + header = get_function_signature(self.code, overwite_fn_name) + # we cannot rely on `co_names`. For example, `from math import sqrt` will make `math` and `sqrt` in `co_names`. + global_names = set(inst.argval for inst in dis.get_instructions(self.code) if inst.opname == "STORE_GLOBAL") + global_statements = "global " + ", ".join( + global_names) + "\n" if global_names else "" + nonlocal_statement = "nonlocal " + ", ".join( + self.code.co_freevars) + "\n" if self.code.co_freevars else "" + source_code = global_statements + nonlocal_statement + source_code + source_code = header + add_indentation(source_code, indentation) + return source_code + except DecompilationError: + raise + except Exception as e: + raise DecompilationError( + f"Failed to decompile {self.code.co_name}") from e + + @staticmethod + def decompile_and_compile_like( + code_to_decompile: CodeType, + reference_code: CodeType, + indentation=4, + temp_prefix: str = "__temp_", + filepath_template: Optional[str] = None) -> CodeType: + + # first, decompile the code into source code, with function name `__place_holder__` + src = Decompiler(code_to_decompile).decompile(indentation=indentation, temp_prefix=temp_prefix, overwite_fn_name="__place_holder__") + + # fix the freevars/cellvars in the source code + from depyf.code_transform import fix_irregular_code + # check https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/4 for why we need to prepare freevars like `reference_code` rather than `code` + src = fix_irregular_code(reference_code, src) + + if filepath_template is None: + func_name = reference_code.co_name + src = src.replace("__place_holder__", func_name) + filename = "noname" + else: + src_body = src[src.find("("):] + if reference_code.co_freevars: + src_body = src_body[src_body.find("("):] + + count = 0 + while True: + filename = filepath_template % count + if os.path.exists(filename): + existing_code = open(filename, "r").read() + existing_code_body = existing_code[existing_code.find("("):] + if reference_code.co_freevars: + existing_code_body = existing_code_body[existing_code_body.find("("):] + if src_body == existing_code_body: + # the same code body is found, we do not need to dump the code again. + src = existing_code + break + else: + count += 1 + else: + func_name = filename.split(os.path.sep)[-1].split(".")[0] + src = src.replace("__place_holder__", func_name) + with open(filename, "w") as f: + f.write(src) + break + + func_name = filename.split(os.path.sep)[-1].split(".")[0] + + from depyf.utils import collect_all_code_objects + transformed_code = compile(src, filename=filename, mode="exec") + transformed_codes = collect_all_code_objects(transformed_code) + decompiled_and_compiled_back_code = [x for x in transformed_codes if x.co_name == func_name][0] + + # torch.compile might hold random non-constant values in `new_code.co_consts` that cannot + # be represented in source code. During decompliation, we treat them as `__co_consts[i]`, + # a string that represents the constant in the original code object. + # We need to replace them with the actual constant in the original code object, so that + # the decompiled and compiled back code object can be used for execution. + updated_consts = [] + for i, x in enumerate(decompiled_and_compiled_back_code.co_consts): + if isinstance(x, str) and x.startswith("__co_consts"): + index = int(x.split("[")[-1][:-1]) # __co_consts[0] -> 0 + updated_consts.append(code_to_decompile.co_consts[index]) + else: + updated_consts.append(x) + + decompiled_and_compiled_back_code = decompiled_and_compiled_back_code.replace(co_consts=tuple(updated_consts)) + + return decompiled_and_compiled_back_code + + def __hash__(self): + # see https://github.com/thuml/depyf/pull/21 + return id(self.code) + + def __eq__(self, other): + return hash(self) == hash(other) + +def decompile(code: Union[CodeType, Callable]) -> str: + """Decompile any callable or code object into Python source code. + It is especially useful for some dynamically generated code, like ``torch.compile``, + or ``dataclasses``. + + Example usage: + + .. code-block:: python + + from dataclasses import dataclass + @dataclass + class Data: + x: int + y: float + + import depyf + print(depyf.decompile(Data.__init__)) + print(depyf.decompile(Data.__eq__)) + + Output: + + .. code-block:: python + + def __init__(self, x, y): + self.x = x + self.y = y + return None + + def __eq__(self, other): + if other.__class__ is self.__class__: + return (self.x, self.y) == (other.x, other.y) + return NotImplemented + + The output source code is semantically equivalent to the function, but not syntactically the same. It verbosely adds many details that are hidden in the Python code. For example, the above output code of ``__init__`` explicitly returns ``None``, which is typically ignored. + + Another detail is that the output code of ``__eq__`` returns ``NotImplemented`` instead of raising ``NotImplemented`` exception when the types are different. At the first glance, it seems to be a bug. However, it is actually the correct behavior. The ``__eq__`` method should return ``NotImplemented`` when the types are different, so that the other object can try to compare with the current object. See `the Python documentation `_ for more details. + """ + return Decompiler(code).decompile() diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__init__.py b/.venv/lib/python3.11/site-packages/depyf/explain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31feb23a2b5235b1c86669dd716914578c778867 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/__init__.py @@ -0,0 +1,17 @@ +from depyf.explain.utils import DynamoOptimizationResult + +from torch._dynamo.eval_frame import innermost_fn + +from typing import List, Callable, Dict, Union, Set +from types import CodeType + + +def _extract_artifacts(original_code: CodeType, module): + result = DynamoOptimizationResult(original_code, None, module) + return result + +def dump_src(original_code: CodeType, module): + from depyf.explain.global_variables import data + assert data["is_inside_prepare_debug"], "`dump_src` must be used inside `depyf.prepare_debug`." + artifacts = _extract_artifacts(original_code, module) + return artifacts.to_src() diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89dbe2b1b1baae31d3f3f6bddda6909428a92f71 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enable_debugging.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enable_debugging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5926436abc6e32ed00feb454f9c3145007b4c97e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enable_debugging.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enhance_logging.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enhance_logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ef196d9424546e5f68f544696b28a4ada83df1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/enhance_logging.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/global_variables.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/global_variables.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b6ed164d2da43fcd58775bdce49751cab00c189 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/global_variables.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched___call__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched___call__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..632b2e0023d0c5644242658972080b5dccb61eac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched___call__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched__exec_with_source.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched__exec_with_source.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d40548459397f8de9eeeccb02e26f15d7ad0fc7c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched__exec_with_source.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_boxed_run.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_boxed_run.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7ca279cb2348267413d88fcd3e1eff48d593d57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_boxed_run.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_lazy_format_graph_code.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_lazy_format_graph_code.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea292e530cc972339e6b0397a7fb43b5f0d8892 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_lazy_format_graph_code.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_load_by_key_path.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_load_by_key_path.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1bd273a7fe56e246bce80933b54e9c041f67e32 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/patched_load_by_key_path.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5a1ed8bd71e6acdf451b69bb3f163483c2be9ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/depyf/explain/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/enable_debugging.py b/.venv/lib/python3.11/site-packages/depyf/explain/enable_debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..61179dfecf87a63c13d101c9ab5d091c55050837 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/enable_debugging.py @@ -0,0 +1,251 @@ +from .patched_boxed_run import patched_boxed_run +from .patched_lazy_format_graph_code import patched_lazy_format_graph_code +from .patched_load_by_key_path import patched_load_by_key_path +from .patched__exec_with_source import patched__exec_with_source +from typing import List, Tuple, Dict, Union, Callable, Optional, Any + +import contextlib +import warnings +import traceback + +import dataclasses +import itertools +import sys +import os +import inspect + + +@dataclasses.dataclass +class DebuggableHook(object): + dump_src_dir: str + log_bytecode: bool + optimized_code_and_module: List =dataclasses.field(default_factory=list, init=False) + + def __call__(self, code, new_code): + frame = sys._getframe() + import os + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == code + self.optimized_code_and_module.append([code, frame.f_globals]) + from depyf.decompiler import DecompilationError + try: + import os + # replace " "/"<"/">"/"." with "_" + func_name = code.co_name.replace(".", "_").replace("<", "_").replace(">", "_").replace(" ", "_") + filepath_template = os.path.join( + self.dump_src_dir, + f"__transformed_code_%s_for_{func_name}.py") + + from depyf.explain.utils import lock_on_file + from depyf.decompiler import Decompiler + + # function name and file name are related. + with lock_on_file(filepath_template): + decompiled_and_compiled_back_code = Decompiler.decompile_and_compile_like(code_to_decompile=new_code, reference_code=code, filepath_template=filepath_template) + filename = decompiled_and_compiled_back_code.co_filename + if self.log_bytecode: + with lock_on_file(filename): + import dill + # code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`. + # see https://github.com/pytorch/pytorch/issues/116013 for more details. + with contextlib.suppress(Exception): + dill.dump(code, open(filename + ".original_bytecode", "wb")) + + with contextlib.suppress(Exception): + dill.dump(new_code, open(filename + ".transformed_bytecode", "wb")) + + with contextlib.suppress(Exception): + dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb")) + + # this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487 + from torch._dynamo.utils import orig_code_map + from torch._dynamo.convert_frame import output_codes + output_codes.add(decompiled_and_compiled_back_code) + orig_code_map[decompiled_and_compiled_back_code] = code + + return decompiled_and_compiled_back_code + except (DecompilationError, SyntaxError) as e: + from io import StringIO + string_io = StringIO() + import dis + print("There is a problem when decompiling and compiling the following code:", file=string_io) + dis.dis(new_code, file=string_io) + print("Please consider submitting an issue to https://github.com/thuml/depyf .", file=string_io) + # do not stop the program for decompilation error and compile error + warnings.warn(string_io.getvalue()) + traceback.print_exc() + +@contextlib.contextmanager +def patch(parent, name, value): + old_value = getattr(parent, name, None) + if old_value is not None: + setattr(parent, name, value) + try: + yield + finally: + if old_value is not None: + setattr(parent, name, old_value) + + +@contextlib.contextmanager +def enable_bytecode_hook(hook): + import torch + handle = torch._dynamo.convert_frame.register_bytecode_hook(hook) + try: + yield + finally: + handle.remove() + + +@contextlib.contextmanager +def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False): + """ + A context manager to dump debugging information for torch.compile. + It should wrap the code that actually triggers the compilation, rather than + the code that applies ``torch.compile``. + + Example: + + .. code-block:: python + + import torch + + @torch.compile + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + + def main(): + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) + + if __name__ == "__main__": + # main() + # surround the code you want to run inside `with depyf.prepare_debug` + import depyf + with depyf.prepare_debug("./dump_src_dir"): + main() + + After running the code, you will find the dumped information in the directory ``dump_src_dir``. The details are organized into the following: + + - ``full_code_for_xxx.py`` for each function using torch.compile + - ``__transformed_code_for_xxx.py`` for Python code associated with each graph. + - ``__transformed_code_for_xxx.py.xxx_bytecode`` for Python bytecode, dumped code object, can be loaded via ``dill.load(open("/path/to/file", "wb"))``. Note that the load function might import some modules like transformers. Make sure you have these modules installed. + - ``__compiled_fn_xxx.py`` for each computation graph and its optimization: + - ``Captured Graph``: a plain forward computation graph + - ``Joint Graph``: joint forward-backward graph from AOTAutograd + - ``Forward Graph``: forward graph from AOTAutograd + - ``Backward Graph``: backward graph from AOTAutograd + - ``kernel xxx``: compiled CPU/GPU kernel wrapper from Inductor. + + Arguments: + + - ``dump_src_dir``: the directory to dump the source code. + - ``clean_wild_fx_code``: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally. + - ``log_bytecode``: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code). + """ + + if not isinstance(dump_src_dir, str): + raise RuntimeError('''You are using an obsolete usage style`depyf.prepare_debug(func=function, dump_src_dir="/path")`. Please use `depyf.prepare_debug(dump_src_dir="/path")` instead, which will automatically capture all compiled functions.''') + + import os + import torch + + current_line_number = inspect.currentframe().f_lineno + 1 + warnings.warn_explicit(f"{__file__}:{current_line_number}: You are trying to debug `torch.compile`. Please make sure the code runs multiple times to cover all the possible branches.", UserWarning, "", 0) + + from depyf.utils import safe_create_directory + + if not os.path.exists(dump_src_dir): + safe_create_directory(dump_src_dir) + + dump_src_dir = os.path.abspath(dump_src_dir) + + from .global_variables import data + + data["dump_src_dir"] = dump_src_dir + data["unpatched__exec_with_source"] = torch.fx.graph_module._exec_with_source + data["unpatched_load_by_key_path"] = torch._inductor.codecache.PyCodeCache.load_by_key_path + data["unpatched___call__"] = torch._dynamo.eval_frame.OptimizeContext.__call__ + data["is_inside_prepare_debug"] = True + + bytecode_hook = DebuggableHook(dump_src_dir, log_bytecode) + + # patch some functions + with patch(torch.fx.graph_module, "_exec_with_source", patched__exec_with_source), \ + patch(torch._inductor.codecache.PyCodeCache, "load_by_key_path", patched_load_by_key_path), \ + patch(torch._dynamo.utils.lazy_format_graph_code, "__code__", patched_lazy_format_graph_code.__code__): + # we have to directly manipulate the code object, since the function has been imported in many places. + # simply replacing torch._dynamo.utils.lazy_format_graph_code does not work for those functions. + # Note: `unitest.mock.patch` does not work here, since it will not + # patch the code object. (it will try to delete the code object and + # then set a new code object. The `delattr` will raise an error.) + + # enable bytecode hook + with enable_bytecode_hook(bytecode_hook): + try: + yield + finally: + + code_names = {x[0].co_name for x in bytecode_hook.optimized_code_and_module} + for code, module in bytecode_hook.optimized_code_and_module: + if code.co_name.startswith("resume_in_") and any(f"resume_in_{name}" in code.co_name for name in code_names): + continue + # https://github.com/pytorch/pytorch/pull/118201 introduces `torch_dynamo_resume_in_` names. + if code.co_name.startswith("torch_dynamo_resume_in_") and any(f"torch_dynamo_resume_in_{name}" in code.co_name for name in code_names): + continue + from depyf.explain import dump_src + from depyf.explain.utils import write_code_to_file_template + from torch._dynamo.eval_frame import innermost_fn, _debug_get_cache_entry_list + entries = _debug_get_cache_entry_list(code) + if not entries: + current_line_number = inspect.currentframe().f_lineno + 1 + warnings.warn_explicit(f"{__file__}:{current_line_number}: Code object {code} is compiled but does not have any compiled cache entries. Probably some torch.nn.Module instances are destroyed too early. It is recommended to make sure the torch.nn.Module instances exist after `with depyf.prepare_debug`.", UserWarning, "", 0) + full_src = dump_src(code, module) + filepath_template = os.path.join(dump_src_dir, f"full_code_for_{code.co_name}_%s.py") + full_code_path = write_code_to_file_template(full_src, filepath_template) + + for file in os.listdir(dump_src_dir): + name = file.split(os.path.sep)[-1] + # remove *.lock file and possibly fx_graph_code* file + if (clean_wild_fx_code and name.startswith("fx_graph_code")) or name.endswith(".lock"): + try: + # multiple processes may try to remove the same file. + os.remove(os.path.join(dump_src_dir, file)) + except OSError: + pass + + data["is_inside_prepare_debug"] = False + +@contextlib.contextmanager +def debug(): + """ + A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program and allows you to check the full source code in files with prefix ``full_code_for_`` in the ``dump_src_dir`` argument of :func:`depyf.prepare_debug`, and set breakpoints in their separate ``__transformed_code_`` files according to the function name. Then continue your debugging. + """ + from .global_variables import data + if data["is_inside_prepare_debug"]: + raise RuntimeError("You cannot use `depyf.debug` inside `depyf.prepare_debug`.") + dump_src_dir = data["dump_src_dir"] + import torch + # after https://github.com/pytorch/pytorch/pull/131258 + # torch._dynamo.eval_frame.set_eval_frame is not available in the module + # we need to directly access it from the `_C` extension. + callback = torch._C._dynamo.eval_frame.set_eval_frame(False) + # sometimes pytorch use Interpreter to run node by node. This cannot be debugged. + # we patch this function to run the graph function directly. + with patch(torch.fx.Interpreter.boxed_run, "__code__", patched_boxed_run.__code__): + try: + msg = f"`depyf` places a breakpoint here to pause the program. You can check the full source code in files with prefix `full_code_for_` in {dump_src_dir} first, and set breakpoints in their separate files according to the function name. Then continue your debugging." + print(msg) + breakpoint() + yield + finally: + torch._C._dynamo.eval_frame.set_eval_frame(callback) diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/enhance_logging.py b/.venv/lib/python3.11/site-packages/depyf/explain/enhance_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad42c0613e62fc21ece1cf39bee120293683b78 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/enhance_logging.py @@ -0,0 +1,94 @@ +import types +from depyf.decompiler import decompile, DecompilationError + + +def pytorch_bytecode_src_hook(code: types.CodeType, new_code: types.CodeType): + import torch + bytecode_log = torch._logging.getArtifactLogger( + "torch._dynamo.convert_frame", "bytecode" + ) + import logging + + if bytecode_log.isEnabledFor(logging.DEBUG): + try: + decompiled_src = decompile(new_code) + bytecode_log.debug("possible source code:") + bytecode_log.debug(decompiled_src) + except DecompilationError as e: + bytecode_log.debug("Decompilation fails due to: %s", str(e)) + finally: + bytecode_log.debug( + "If you find the decompiled code is wrong," + "please submit an issue at " + "https://github.com/thuml/depyf/issues." + ) + + +_handle = None + + +def install(): + """ + Install the bytecode hook for PyTorch, integrate into PyTorch's logging system. + + Example: + + .. code-block:: python + + import torch + import depyf + depyf.install() + # anything with torch.compile + @torch.compile + def f(a, b): + return a + b + f(torch.tensor(1), torch.tensor(2)) + + Turn on bytecode log by ``export TORCH_LOGS="+bytecode"``, and execute the script. + We will see the decompiled source code in the log: + + .. code-block:: text + + ORIGINAL BYTECODE f test.py line 5 + 7 0 LOAD_FAST 0 (a) + 2 LOAD_FAST 1 (b) + 4 BINARY_ADD + 6 RETURN_VALUE + + + MODIFIED BYTECODE f test.py line 5 + 5 0 LOAD_GLOBAL 0 (__compiled_fn_1) + 2 LOAD_FAST 0 (a) + 4 LOAD_FAST 1 (b) + 6 CALL_FUNCTION 2 + 8 UNPACK_SEQUENCE 1 + 10 RETURN_VALUE + + + possible source code: + def f(a, b): + __temp_2, = __compiled_fn_1(a, b) + return __temp_2 + + If you find the decompiled code is wrong,please submit an issue at https://github.com/thuml/depyf/issues. + + To uninstall the hook, use :func:`depyf.uninstall()`. + """ + import torch + global _handle + if _handle is not None: + return + _handle = torch._dynamo.convert_frame.register_bytecode_hook( + pytorch_bytecode_src_hook) + + +def uninstall(): + """ + Uninstall the bytecode hook for PyTorch. + Should be called after :func:`depyf.install()`. + """ + global _handle + if _handle is None: + return + _handle.remove() + _handle = None diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/global_variables.py b/.venv/lib/python3.11/site-packages/depyf/explain/global_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..884186c084d525c76914bf22b21c199a3db5ef2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/global_variables.py @@ -0,0 +1,14 @@ +import os + +import torch + +from torch._inductor.codecache import PyCodeCache + +data = { + "dump_src_dir": os.path.join(os.path.dirname(__file__), "dumped_src"), + "unpatched__exec_with_source": torch.fx.graph_module._exec_with_source, + "unpatched_load_by_key_path": PyCodeCache.load_by_key_path, + "unpatched___call__": torch._dynamo.eval_frame.OptimizeContext.__call__, + "optimized_functions": set(), + "is_inside_prepare_debug": False, +} diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/patched___call__.py b/.venv/lib/python3.11/site-packages/depyf/explain/patched___call__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d033c774f8d8cb18e177ef07d4f965d097aa28 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/patched___call__.py @@ -0,0 +1,9 @@ +def patched___call__(self, code, check_fn): + from depyf.explain.global_variables import data + from depyf.utils import get_code_owner + import torch + unpatched___call__ = data["unpatched___call__"] + optimized_functions = data["optimized_functions"] + optimized_functions.add(code) + + return unpatched___call__(self, code, check_fn) \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/patched__exec_with_source.py b/.venv/lib/python3.11/site-packages/depyf/explain/patched__exec_with_source.py new file mode 100644 index 0000000000000000000000000000000000000000..7982f802e4326db8345ccf25969d5b584afd8f1d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/patched__exec_with_source.py @@ -0,0 +1,20 @@ +def patched__exec_with_source(src: str, globals, co_fields=None): + from depyf.explain.global_variables import data + from depyf.explain.utils import write_code_to_file_template + dump_src_dir = data["dump_src_dir"] + unpatched__exec_with_source = data["unpatched__exec_with_source"] + unpatched__exec_with_source(src, globals, co_fields) + import inspect + key = inspect.getsourcefile(globals["forward"]) + import hashlib + import os + hash_value = hashlib.md5(src.encode()).hexdigest() + src = "# " + key + src + filename = write_code_to_file_template( + src, + f"{dump_src_dir}/fx_graph_code_" + + hash_value + + "_" + + "%s" + + ".py") + exec(compile(src, filename, "exec"), globals) diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/patched_boxed_run.py b/.venv/lib/python3.11/site-packages/depyf/explain/patched_boxed_run.py new file mode 100644 index 0000000000000000000000000000000000000000..0fcbafe4ffdc132d40d1a8aca4bd246ba5bac412 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/patched_boxed_run.py @@ -0,0 +1,2 @@ +def patched_boxed_run(self, args_list): + return self.module.forward(*args_list) diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/patched_lazy_format_graph_code.py b/.venv/lib/python3.11/site-packages/depyf/explain/patched_lazy_format_graph_code.py new file mode 100644 index 0000000000000000000000000000000000000000..28217d9227162618ed65314872a420cca3a2d5e3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/patched_lazy_format_graph_code.py @@ -0,0 +1,78 @@ +def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): + from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template + from depyf.utils import get_code_owner + # When using torch export, the name includes + # a dumped dict of the nn_module_stack of a node in the module after the ':' + if ':' in name: + name = name.split(':')[0] + func_name = get_current_compiled_fn_name() + file_name = name if name != func_name else "Captured Graph" + file_name = file_name.replace(" ", "_") + file_name = func_name + "." + file_name + import inspect + import os + + # https://github.com/pytorch/pytorch/pull/117911 introduces LazyGraphModule + # whose `forward` method is mangled and cannot be manipulated. + # We need to get rid of the laziness by calling `str` on it. + gm_s = str(gm) + + fn = gm.forward + + fn = get_code_owner(fn) + + # update file path + filepath = inspect.getsourcefile(fn) + # try to use verbose code with type and shape annotations + use_gm = True + + # use `print_readable` because it can include submodules + src = "from __future__ import annotations\nimport torch\n" + \ + gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") + try: + compile(src, "noname", "exec") + except Exception as e: + # the pytorch version is before this PR: https://github.com/pytorch/pytorch/pull/113345 + # Verbose code contains syntax error, it is recommended to use new + # version of PyTorch to get runnable code with shape and type + # annotations. + simple_code = gm._graph.python_code(root_module="self", verbose=False).src + commented_src = "\n# code below is commented out due to syntax error. You can refer to the code for shape and dtype annotation.\n" + commented_src += "".join(["# " + line + + "\n" for line in src.splitlines()]) + src = simple_code + commented_src + use_gm = False + if filepath is not None: + new_filepath = write_code_to_file_template( + src, os.path.dirname(filepath) + "/" + file_name + "." + "%s" + ".py") + scope = fn.__globals__ + exec(compile(src, filename=new_filepath, mode="exec"), scope) + if use_gm: + import torch + classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)] + assert len(classes) == 1 + module_class = classes[0] + fn.__code__ = getattr(module_class, fn.__name__).__code__ + else: + fn.__code__ = scope[fn.__name__].__code__ + del scope[fn.__name__] + + # ========================================= + # original code of `lazy_format_graph_code` + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + if "print_output" not in kwargs: + kwargs["print_output"] = False + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(**kwargs), + ) + ) diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/patched_load_by_key_path.py b/.venv/lib/python3.11/site-packages/depyf/explain/patched_load_by_key_path.py new file mode 100644 index 0000000000000000000000000000000000000000..14fdaaf4286de09af9321182a4aac881cde2aea6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/patched_load_by_key_path.py @@ -0,0 +1,21 @@ +def patched_load_by_key_path( + key: str, + path: str, + linemap, + attrs, +): + from depyf.explain.global_variables import data + from depyf.explain.utils import write_code_to_file_template, get_current_compiled_fn_name + dump_src_dir = data["dump_src_dir"] + unpatched_load_by_key_path = data["unpatched_load_by_key_path"] + import os + # hack the path to our dump_src_dir + src = open(path).read() + # do not remove. remove in multi-processes will cause error. + # os.remove(path) + + func_name = get_current_compiled_fn_name() + new_filepath = write_code_to_file_template(src, os.path.join( + dump_src_dir, func_name + ".kernel_" + "%s" + ".py")) + path = new_filepath + return unpatched_load_by_key_path(key, path, linemap, attrs) diff --git a/.venv/lib/python3.11/site-packages/depyf/explain/utils.py b/.venv/lib/python3.11/site-packages/depyf/explain/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe67747d4fcde53735a1c4a8c06c9dbb792775a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/explain/utils.py @@ -0,0 +1,338 @@ +import torch +from torch._dynamo.eval_frame import innermost_fn +from torch._dynamo.eval_frame import _debug_get_cache_entry_list +import inspect + +import dis +from types import CodeType +from typing import List, Callable, Dict, Union, Set +from dataclasses import dataclass +import contextlib + +class CodeProxy: + instances: Dict[str, "CodeProxy"] = {} + used_instances: Set[str] = set() + + @staticmethod + def get_new_name(name: str): + i = 0 + new_name = name + if new_name.endswith(":"): + name = name[:-1] + while True: + new_name = f"{name}_{i}" + if new_name not in CodeProxy.instances: + break + i += 1 + return new_name + + @staticmethod + def consume_new_name(name: str): + new_name = CodeProxy.get_new_name(name) + CodeProxy.instances[new_name] = None + return new_name + + @staticmethod + def decompile_with_name(code: CodeType, name: str, skip_decompile=False): + from depyf.utils import decompile_ensure + if hasattr(code, "__code__"): + code = code.__code__ + if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"): + src = open(code.co_filename).read() + new_name = code.co_name + else: + new_name = CodeProxy.get_new_name(name) + if not skip_decompile: + src = decompile_ensure(code, new_name) + else: + src = "" + self = CodeProxy(src) + self.name = new_name + self.code = f"""
+ {self.name} + + ```python +{self.raw_code} + ``` +
+""" + CodeProxy.instances[self.name] = self + return self + + def __init__(self, code: str): + # Don't directly use this constructor. Use decompile_with_name instead. + self.raw_code = "".join( + [" " + line + "\n" for line in code.splitlines() if line.strip() != ""]) + + def __str__(self): + CodeProxy.used_instances.add(self.name) + return self.name + + @contextlib.contextmanager + @staticmethod + def record(): + CodeProxy.used_instances = set() + yield CodeProxy.used_instances + + +@dataclass +class CacheResult: + original_code: CodeType + transformed_code: CodeType + guard: List[str] + compiled_subgraph: Callable + compiled_subgraph_proxy: CodeProxy + transformed_code_proxy: CodeProxy + referenced_global_functions: Dict[str, "DynamoOptimizationResult"] + + def __init__(self, original_code, module, cache): + self.original_code = original_code + + cpp_guard = False + + # starting from https://github.com/pytorch/pytorch/pull/138896 , + # pytorch uses `guard_manager` instead of `check_fn` to store the + # guards + attr_name = "guard_manager" if hasattr(cache, "guard_manager") else "check_fn" + + guard_manager = getattr(cache, attr_name) + + try: + klass = getattr(torch._dynamo.guards, "GuardManagerWrapper", None) or \ + getattr(torch._dynamo.guards, "GuardManager", None) or \ + getattr(torch._C._dynamo.guards, "GuardManager", None) + assert klass is not None + cpp_guard = isinstance(guard_manager, klass) + except Exception: + pass + + if not cpp_guard: + # for old version of pytorch, + # `guard_manager` is a plain python function + guard_codes = guard_manager.code_parts + freevar_names = guard_manager.__code__.co_freevars + freevar_values = [x.cell_contents for x in guard_manager.__closure__] + else: + # keep the logic synced with + # https://github.com/pytorch/pytorch/blob/7b6b10417d8616ebd7a42b06528c5c2b2fded55a/torch/_dynamo/guards.py#L262 + tensor_aliasing_guard_seen = False + def visit(root, ans): + nonlocal tensor_aliasing_guard_seen + for leaf_guard in root.get_leaf_guards(): + if isinstance(leaf_guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): + if not tensor_aliasing_guard_seen: + tensor_aliasing_guard_seen = True + else: + continue + append_guard_code(leaf_guard, ans) + for child in root.get_child_managers(): + visit(child, ans) + guard_codes = [] + root = guard_manager.root + + # Add guards in RootGuardManager + visit(root, guard_codes) + # Add guards in epilogue lambda guards + if hasattr(root, "get_epilogue_lambda_guards"): + for lambda_guard in root.get_epilogue_lambda_guards(): + append_guard_code(lambda_guard, guard_codes) + + if guard_manager.closure_vars is None: + freevar_names = tuple() + freevar_values = [] + else: + freevar_names = tuple(guard_manager.closure_vars.keys()) + freevar_values = list(guard_manager.closure_vars.values()) + + self.guard = guard_codes + self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)} + code = cache.code + + compiled_subgraphs = [ + name for name in code.co_names if name.startswith("__compiled")] + assert len(compiled_subgraphs) <= 1 + + if compiled_subgraphs: + # deal with compiled_subgraph + self.compiled_subgraph = innermost_fn(module[compiled_subgraphs[0]]) + # subgraph does not need decompile + self.compiled_subgraph_proxy = CodeProxy.decompile_with_name( + self.compiled_subgraph, compiled_subgraphs[0], skip_decompile=True) + else: + self.compiled_subgraph = None + self.compiled_subgraph_proxy = None + # deal with transformed_code + self.transformed_code = code + self.transformed_code_proxy = CodeProxy.decompile_with_name( + self.transformed_code, "transformed_code:") + resume_fns = [ + name for name in code.co_names if name.startswith("__resume")] + self.referenced_global_functions = {} + for name in resume_fns: + self.referenced_global_functions[name] = DynamoOptimizationResult( + original_code=module[name].__code__, + function_name=name, + module=module) + + def to_data(self): + return { + "guard": self.guard, + "transformed_code": str( + self.transformed_code_proxy), + "compiled_subgraph": str( + self.compiled_subgraph_proxy) if self.compiled_subgraph_proxy is not None else '"No compiled subgraph."', + "referenced_global_functions": { + name: fn.to_data() for name, + fn in self.referenced_global_functions.items()}} + + +@dataclass +class DynamoOptimizationResult: + function_name: str + module: dict + original_code: CodeType + source_code_proxy: CodeProxy + transformed_code_entries: List[CacheResult] + + def __init__(self, original_code, function_name=None, module=None): + self.original_code = original_code + if function_name is None: + self.function_name = original_code.co_name + else: + self.function_name = function_name + self.module = module + caches = _debug_get_cache_entry_list(original_code) + self.transformed_code_entries = [ + CacheResult(original_code, module, cache) for cache in caches] + self.source_code_proxy = CodeProxy.decompile_with_name( + self.original_code, self.function_name) + + def to_data(self): + data = { + "function_name": self.function_name, + "source_code": str( + self.source_code_proxy), + "transformed_code_entries": [ + entry.to_data() for entry in self.transformed_code_entries]} + return data + + def to_src(self): + raw_code = self.source_code_proxy.raw_code + + # prepare function signature, from `def toy_example(a, b)` to `def + # transformed_toy_example(a, b)` + signature = raw_code.splitlines()[0].replace( + "def ", "def transformed_", 1) + code = signature.strip() + + # prepare args for guards, like `L = {"a": a, "b": b}` + code_obj = self.original_code + normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount + arg_names = code_obj.co_varnames[:normal_arg_count] + arg_dict = "__local_dict = {" + \ + ", ".join([f'"{name}": {name}' for name in arg_names]) + "}" + code += "\n" + " " * 4 + arg_dict + code += "\n" + " " * 4 + "__global_dict = globals()" + + additional_code = "" + + for entry in self.transformed_code_entries: + + # prepare guards, like `def guard_0(L):\n return a > 0 and b > + # 0` + freevars = "".join([f"{name} = '''{value}'''\n" for name, value in entry.freevars.items() if name not in ["__builtins__"]]) + if freevars: + freevars = "# Note: the following variables are used inside the guard function.\n" + freevars + guard_lines = [" " * 4 + "__guard_hit = True\n"] + for x in entry.guard: + guard_lines.append(" " * 4 + f"__guard_hit = __guard_hit and {x}\n") + guard_lines.append(" " * 4 + "return __guard_hit\n") + guard = "".join(guard_lines) + if entry.transformed_code_proxy.name.startswith("__transformed_code_"): + guard_func_name = entry.transformed_code_proxy.name.replace("__transformed_code_", "__guard_") + else: + guard_func_name = CodeProxy.consume_new_name("guard:") + additional_code += "\n" + freevars + f"def {guard_func_name}(L, G, **___kwargs_ignored):\n" + guard + + if entry.compiled_subgraph_proxy is not None: + # prepare compiled subgraph, like `__compiled_fn_0` + subgraph_name = entry.compiled_subgraph_proxy.name + additional_code += "\n" + additional_code += f"# Note: please refer to the graph code in {subgraph_name}*.py.\n" + additional_code += f"# Captured Graph: Dynamo generated graph (debuggable when using eager backend).\n" + additional_code += f"# Joint graph: joint forward+backward graph from aot autograd.\n" + additional_code += f"# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).\n" + additional_code += f"# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).\n" + additional_code += f"# AFTER XXX: graph processed by inductor (not debuggable).\n" + additional_code += f"def {subgraph_name}(*args, **kwargs):\n pass\n" + + # prepare transformed code, like `transformed_code_0` + additional_code += "\n" + \ + remove_indentation(entry.transformed_code_proxy.raw_code) + "\n" + + for name, func in entry.referenced_global_functions.items(): + additional_code = func.to_src() + additional_code + + code += "\n" + " " * 4 + \ + f"if {guard_func_name}(__local_dict, __global_dict):\n" + " " * 8 + f"return {entry.transformed_code_proxy.name}({', '.join(arg_names)})" + + additional_code += "\n" + "# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.\n" + \ + remove_indentation(self.source_code_proxy.raw_code) + "\n" + + code += "\n" + " " * 4 + "# Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.\n" + \ + " " * 4 + f"return {self.source_code_proxy.name}({', '.join(arg_names)})" + return additional_code + code + \ + f"\n\n#============ end of {self.function_name} ============#\n" + + +def remove_indentation(code: str): + lines = code.splitlines() + indent = len(lines[0]) - len(lines[0].lstrip()) + return "".join([line[indent:] + "\n" for line in lines]) + +def append_guard_code(guard, ans): + for verbose_str in guard.verbose_code_parts(): + verbose_str = verbose_str.strip() + ans.append(verbose_str) + +from contextlib import contextmanager + +@contextmanager +def lock_on_file(path_template): + lock_path = path_template + ".lock" + from filelock import FileLock + import os + lock = FileLock(lock_path) + try: + with lock: + yield + finally: + pass + + +def write_code_to_file_template(src, path_template): + with lock_on_file(path_template): + import os + count = 0 + while True: + new_filepath = path_template % str(count) + if not os.path.exists(new_filepath): + with open(new_filepath, "w") as f: + f.write(src) + break + # might be a hash collision + existing_code = open(new_filepath).read() + if existing_code == src: + break + count += 1 + return new_filepath + + +def get_current_compiled_fn_name(): + import torch + from torch._dynamo.bytecode_transformation import _unique_id_counter + from copy import copy + # torch.compile already called the next, we should add minus 1 to get the + # correct name + current_count = next(copy(_unique_id_counter)) - 1 + return "__compiled_fn_" + str(current_count) diff --git a/.venv/lib/python3.11/site-packages/depyf/optimization.py b/.venv/lib/python3.11/site-packages/depyf/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bdaf18aeee255439d784ecaf56b5bc688e4293 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/optimization.py @@ -0,0 +1,74 @@ +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, List + +import torch + + +class TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, compiled_callable: Callable, use_custom_dispatcher: bool = True): + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + self.use_custom_dispatcher: bool = use_custom_dispatcher + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + frame = sys._getframe() + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/.venv/lib/python3.11/site-packages/depyf/utils.py b/.venv/lib/python3.11/site-packages/depyf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92a224f50dc458b7b1ecef7e978a670fd0e2cf07 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/depyf/utils.py @@ -0,0 +1,90 @@ +import dis +from typing import List, Tuple, Union, Optional, Callable, Any, Dict, Set +from types import CodeType + + +def get_function_signature(code_obj: CodeType, + overwite_fn_name: Optional[str] = None) -> str: + # Extract all required details from the code object + # Sometimes the code object does not have a name, e.g. when it is a lambda + # function, so we can overwrite it to be a valid name + normal_arg_count = code_obj.co_argcount + code_obj.co_kwonlyargcount + arg_names = code_obj.co_varnames[:normal_arg_count] + arg_names = [ + x if not x.startswith(".") else x.replace( + ".", "comp_arg_") for x in arg_names] + + import inspect + if code_obj.co_flags & inspect.CO_VARARGS: + arg_names.append('*' + code_obj.co_varnames[normal_arg_count]) + normal_arg_count += 1 + if code_obj.co_flags & inspect.CO_VARKEYWORDS: + arg_names.append('**' + code_obj.co_varnames[normal_arg_count]) + normal_arg_count += 1 + args_str = ', '.join(arg_names) + fn_name = overwite_fn_name if overwite_fn_name is not None else code_obj.co_name + header = f"def {fn_name}({args_str}):\n" + return header + + +def collect_all_code_objects(code: CodeType) -> List[CodeType]: + code_objects = [code] + for const in code.co_consts: + if isinstance(const, type(code)): + code_objects.extend(collect_all_code_objects(const)) + return code_objects + + +def safe_create_directory(path): + # allow multiple processes to create the same directory + import os + try: + os.makedirs(path, exist_ok=True) + except OSError as e: + if not os.path.isdir(path): + raise + + + +def get_code_owner(fn): + """A callable object `fn` might have a __code__ attribute, which is a code object. + However, `fn` might not be the owner of the code object. Only the code owner can change the code object. + This function returns the owner of the code object. + An example: + class A: + def func(self): + return 1 + a = A() + `a.func.__code__` is read-only. `A.func.__code__` is writable. + We can change the code object via `a.func.__func__.__code__`. + """ + import functools + while True: + if hasattr(fn, "__func__"): + # deal with bounded function + fn = fn.__func__ + elif hasattr(fn, "__wrapped__"): + # deal with lru_cache or other decorators + fn = fn.__wrapped__ + elif isinstance(fn, functools.partial): + # deal with partial function + fn = fn.func + elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"): + # deal with callable object + fn = fn.__call__.__func__ + else: + break + return fn + + + +def decompile_ensure(fn: CodeType, overwite_fn_name=None): + import depyf + from depyf.decompiler import DecompilationError + try: + decompiled_source_code = depyf.Decompiler( + fn).decompile(overwite_fn_name=overwite_fn_name) + except DecompilationError as e: + header = get_function_signature(fn, overwite_fn_name=overwite_fn_name) + decompiled_source_code = header + " 'Failed to decompile.'\n" + return decompiled_source_code diff --git a/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..cdd68a497cdfa8d3f2b837225beacef711b85047 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/platformdirs-4.3.6.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.25.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/LICENSE b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..38438c121a8731fcf91c7cb6cb268baccf24fc4c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014-2022 Matthew Brennan Jones + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/METADATA b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..3f2fd71bbe71d450630269664cdb269151ad9b1a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/METADATA @@ -0,0 +1,27 @@ +Metadata-Version: 2.1 +Name: py-cpuinfo +Version: 9.0.0 +Summary: Get CPU info with pure Python +Home-page: https://github.com/workhorsy/py-cpuinfo +Author: Matthew Brennan Jones +Author-email: matthew.brennan.jones@gmail.com +License: MIT +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Topic :: Utilities +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +License-File: LICENSE + +py-cpuinfo +========== + + +Py-cpuinfo gets CPU info with pure Python. Py-cpuinfo should work +without any extra programs or libraries, beyond what your OS provides. +It does not require any compilation(C/C++, assembly, et cetera) to use. +It works with Python 3. + +Documentation can be viewed here: https://github.com/workhorsy/py-cpuinfo + + diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/RECORD b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..b1e4a940cd272e578100fe8a279bca9c999626cc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/RECORD @@ -0,0 +1,14 @@ +../../../bin/cpuinfo,sha256=UfxJjvVjzhK2GanXPckc2fcNn-4NDzyKR188s6H-PNo,224 +cpuinfo/__init__.py,sha256=T6gndqGAggfJCu4_iOziTnomCN7KzaAK_OYTewE4FMA,44 +cpuinfo/__main__.py,sha256=nSxC6Hqhi-0lN7Z4WwtKdxQdf3cUJefb5hOahCzh4Yg,33 +cpuinfo/__pycache__/__init__.cpython-311.pyc,, +cpuinfo/__pycache__/__main__.cpython-311.pyc,, +cpuinfo/__pycache__/cpuinfo.cpython-311.pyc,, +cpuinfo/cpuinfo.py,sha256=HHyDlDUNovE3QzJ3hviiM1ngyOC4iD7i6oGiz2iTmVk,84388 +py_cpuinfo-9.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +py_cpuinfo-9.0.0.dist-info/LICENSE,sha256=3br3Y5a_XHqkWXWiHq_i4i7st9paoNt8sOYVL6r-800,1127 +py_cpuinfo-9.0.0.dist-info/METADATA,sha256=rRFelvhFdoYcXnXXYDAbgdIxQ8_iVUa5lUHgEmU3ncE,794 +py_cpuinfo-9.0.0.dist-info/RECORD,, +py_cpuinfo-9.0.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +py_cpuinfo-9.0.0.dist-info/entry_points.txt,sha256=ZwrsclY_xUA0xJZK98bLxBdcowxnkK0ANYUT4FYcZJ8,42 +py_cpuinfo-9.0.0.dist-info/top_level.txt,sha256=XsjpunhkxD4hvznqQjrFNw0rtgizHEOGzewPZY3UEtU,8 diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..becc9a66ea739ba941d48a749e248761cc6e658a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/entry_points.txt b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..c10718f4d497f1e333eaec47651ab41f5d196efc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/entry_points.txt @@ -0,0 +1,3 @@ +[console_scripts] +cpuinfo = cpuinfo:main + diff --git a/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/top_level.txt b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..b53b02d61061b32d70bf375f63e0e5d3ee8d4a1d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/py_cpuinfo-9.0.0.dist-info/top_level.txt @@ -0,0 +1 @@ +cpuinfo diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__init__.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6302a8b97f680277d60237094680c72431969d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/layers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0b77c0e5bd4f1caccbd5d82dbfbe299ee2d6b4e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/layers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/models.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8122aa56f0b8b979c6c58a8bc311942078c9e4a6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/request.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/request.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ef7d835618b55b4207935e357b51dcec92d6d9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/request.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffbb9b341afe0bc267dee3427808b223528055d3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/worker_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/worker_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a72f7261d6d407332aa3849ab175584d6da82fb9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/__pycache__/worker_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/layers.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..18e0c5227d45c06b48a62fba99908d7521736b02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/layers.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/models.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a5d2fffad5e62ba4bb69d9f31419bf8a281f3b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/models.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[int, T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def capacity(self) -> int: + raise NotImplementedError + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + raise NotImplementedError + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/request.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/request.py new file mode 100644 index 0000000000000000000000000000000000000000..2b604b91bbb6b43e0e5b63fc555c7cef15c88f3d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/request.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + + +class AdapterRequest(ABC): + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self) -> int: + raise NotImplementedError + + def __post_init__(self) -> None: + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/utils.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2dc5433cc65671fa99e2fb1ec4855cd4f1e2c68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/utils.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/.venv/lib/python3.11/site-packages/vllm/adapter_commons/worker_manager.py b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ce24e08a5b56ef441f60d937c27193d3e446f10f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + raise NotImplementedError + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Set[int]: + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__init__.py b/.venv/lib/python3.11/site-packages/vllm/compilation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecdf784395aff39ca4a41c8237a539c76fde28de Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/backends.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/backends.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc03bc5a49f13abb1112ea0a43738682e04c68f0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/backends.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/counter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/counter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20a36e9d56e1eb4ca51df93b10d0aab384d73b4c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/counter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/decorators.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/decorators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ca5ce2a33f015722ed8b5097f9f234c7eeb6a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/decorators.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fix_functionalization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fix_functionalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..590406723d7e0c53944776602062f2135ae282d9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fix_functionalization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fusion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1358cd85562e58db6950cab6053941046c29a679 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fusion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fx_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fx_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4e5fdb324d84b9c3c4fa84ffa8e30e52dcc40ef Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/fx_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/inductor_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/inductor_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6fc978b4382a1eb4c8e59062dab5ae50241162 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/inductor_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/monitor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/monitor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41abd16a73a376ee6ee83cd31710232b7dd3c004 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/monitor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/multi_output_match.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/multi_output_match.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628927d31a6a4dbe243bcb3657109b232dc931bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/multi_output_match.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/pass_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/pass_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c221867863ec30dee7a375a5ca683ebde7ed314 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/pass_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/reshapes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/reshapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69419a63a11a066598e0144e0ef008fbdd3ea8fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/reshapes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/vllm_inductor_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/vllm_inductor_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8b74a07a574b7a2155bb29cd00c48a88f15552 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/vllm_inductor_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/wrapper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ba65fcc12d1bc834c01fe75af5df6e8313c431e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/compilation/__pycache__/wrapper.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/inductor_pass.py b/.venv/lib/python3.11/site-packages/vllm/compilation/inductor_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..be663946f4d815db4ca9d48eca2d77f904db7b52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/compilation/inductor_pass.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import inspect +import types +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx + + +class InductorPass(ABC): + """ + General custom inductor pass interface. + TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass + """ + + @abstractmethod + def __call__(self, graph: torch.fx.Graph): + """ + Execute the pass on the given graph. + """ + raise NotImplementedError + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.digest() + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + if uuid is None: + uuid = InductorPass.hash_source(callable) + self._uuid = uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + def __getstate__(self): + """ + Pickling occurs in the Inductor code cache if a pass is not given to + the pass manager but is instead directly added to config as a pass. + See PostGradPassManager for more. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + return self._uuid + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CallableInductorPass") diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/multi_output_match.py b/.venv/lib/python3.11/site-packages/vllm/compilation/multi_output_match.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6a60b25950eb50fce0c9c2b2905cb20ff433f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/compilation/multi_output_match.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 + +import abc +import operator +from abc import abstractmethod +from typing import Iterable, List, Tuple + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload +from torch.fx import Node + +from vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. + + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> List[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> Tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: + """ + Insert an auto_functionalized node with the given op and kwargs. + """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/reshapes.py b/.venv/lib/python3.11/site-packages/vllm/compilation/reshapes.py new file mode 100644 index 0000000000000000000000000000000000000000..292baae852822d739808aa9a71b58fee77a251b9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/compilation/reshapes.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class RedundantReshapesPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. + + Example graph: + + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_reshapes") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if all( + self.dims_equivalent(s, i_s) + for s, i_s in zip(shape, input_shape)): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes", count) + + self.dump_graph(graph, "after_reshapes") + self.end_and_log() + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/.venv/lib/python3.11/site-packages/vllm/compilation/vllm_inductor_pass.py b/.venv/lib/python3.11/site-packages/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2597e42711fcf99e1dc740463c9fdb7b91e295 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch + +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: CompilationConfig.PassConfig): + self.config = config + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + if stage in self.config.dump_graph_stages: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) diff --git a/.venv/lib/python3.11/site-packages/vllm/usage/__init__.py b/.venv/lib/python3.11/site-packages/vllm/usage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c38f619f5cdb8fb7135a0ebec2249725f41347 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/usage_lib.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/usage_lib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65050c85bc4d8ef36277557e18b33adbd2112477 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/usage/__pycache__/usage_lib.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/usage/usage_lib.py b/.venv/lib/python3.11/site-packages/vllm/usage/usage_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fbbb21c89370a1e8f8b7155c5220acad3d870ff8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/usage/usage_lib.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import json +import logging +import os +import platform +import time +from enum import Enum +from pathlib import Path +from threading import Thread +from typing import Any, Dict, Optional, Union +from uuid import uuid4 + +import cpuinfo +import psutil +import requests +import torch + +import vllm.envs as envs +from vllm.connections import global_http_connection +from vllm.version import __version__ as VLLM_VERSION + +_config_home = envs.VLLM_CONFIG_ROOT +_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "usage_stats.json") +_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "do_not_track") +_USAGE_STATS_ENABLED = None +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER + +_GLOBAL_RUNTIME_DATA: Dict[str, Union[str, int, bool]] = {} + +_USAGE_ENV_VARS_TO_COLLECT = [ + "VLLM_USE_MODELSCOPE", + "VLLM_USE_TRITON_FLASH_ATTN", + "VLLM_ATTENTION_BACKEND", + "VLLM_USE_FLASHINFER_SAMPLER", + "VLLM_PP_LAYER_PARTITION", + "VLLM_USE_TRITON_AWQ", + "VLLM_USE_V1", + "VLLM_ENABLE_V1_MULTIPROCESSING", +] + + +def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None: + """Set global usage data that will be sent with every usage heartbeat.""" + _GLOBAL_RUNTIME_DATA[key] = value + + +def is_usage_stats_enabled(): + """Determine whether or not we can send usage stats to the server. + The logic is as follows: + - By default, it should be enabled. + - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 + - DO_NOT_TRACK=1 + - VLLM_NO_USAGE_STATS=1 + - A file in the home directory can disable it if it exists: + - $HOME/.config/vllm/do_not_track + """ + global _USAGE_STATS_ENABLED + if _USAGE_STATS_ENABLED is None: + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS + do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) + + _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats + or do_not_track_file) + return _USAGE_STATS_ENABLED + + +def _get_current_timestamp_ns() -> int: + return int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1e9) + + +def _detect_cloud_provider() -> str: + # Try detecting through vendor file + vendor_files = [ + "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_name", + "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + ] + # Mapping of identifiable strings to cloud providers + cloud_identifiers = { + "amazon": "AWS", + "microsoft corporation": "AZURE", + "google": "GCP", + "oraclecloud": "OCI", + } + + for vendor_file in vendor_files: + path = Path(vendor_file) + if path.is_file(): + file_content = path.read_text().lower() + for identifier, provider in cloud_identifiers.items(): + if identifier in file_content: + return provider + + # Try detecting through environment variables + env_to_cloud_provider = { + "RUNPOD_DC_ID": "RUNPOD", + } + for env_var, provider in env_to_cloud_provider.items(): + if os.environ.get(env_var): + return provider + + return "UNKNOWN" + + +class UsageContext(str, Enum): + UNKNOWN_CONTEXT = "UNKNOWN_CONTEXT" + LLM_CLASS = "LLM_CLASS" + API_SERVER = "API_SERVER" + OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" + ENGINE_CONTEXT = "ENGINE_CONTEXT" + + +class UsageMessage: + """Collect platform information and send it to the usage stats server.""" + + def __init__(self) -> None: + # NOTE: vLLM's server _only_ support flat KV pair. + # Do not use nested fields. + + self.uuid = str(uuid4()) + + # Environment Information + self.provider: Optional[str] = None + self.num_cpu: Optional[int] = None + self.cpu_type: Optional[str] = None + self.cpu_family_model_stepping: Optional[str] = None + self.total_memory: Optional[int] = None + self.architecture: Optional[str] = None + self.platform: Optional[str] = None + self.cuda_runtime: Optional[str] = None + self.gpu_count: Optional[int] = None + self.gpu_type: Optional[str] = None + self.gpu_memory_per_device: Optional[int] = None + self.env_var_json: Optional[str] = None + + # vLLM Information + self.model_architecture: Optional[str] = None + self.vllm_version: Optional[str] = None + self.context: Optional[str] = None + + # Metadata + self.log_time: Optional[int] = None + self.source: Optional[str] = None + + def report_usage(self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[Dict[str, Any]] = None) -> None: + t = Thread(target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True) + t.start() + + def _report_usage_worker(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: Dict[str, Any]) -> None: + self._report_usage_once(model_architecture, usage_context, extra_kvs) + self._report_continous_usage() + + def _report_usage_once(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: Dict[str, Any]) -> None: + # Platform information + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + device_property = torch.cuda.get_device_properties(0) + self.gpu_count = torch.cuda.device_count() + self.gpu_type = device_property.name + self.gpu_memory_per_device = device_property.total_memory + if current_platform.is_cuda(): + self.cuda_runtime = torch.version.cuda + self.provider = _detect_cloud_provider() + self.architecture = platform.machine() + self.platform = platform.platform() + self.total_memory = psutil.virtual_memory().total + + info = cpuinfo.get_cpu_info() + self.num_cpu = info.get("count", None) + self.cpu_type = info.get("brand_raw", "") + self.cpu_family_model_stepping = ",".join([ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")) + ]) + + # vLLM information + self.context = usage_context.value + self.vllm_version = VLLM_VERSION + self.model_architecture = model_architecture + + # Environment variables + self.env_var_json = json.dumps({ + env_var: getattr(envs, env_var) + for env_var in _USAGE_ENV_VARS_TO_COLLECT + }) + + # Metadata + self.log_time = _get_current_timestamp_ns() + self.source = envs.VLLM_USAGE_SOURCE + + data = vars(self) + if extra_kvs: + data.update(extra_kvs) + + self._write_to_file(data) + self._send_to_server(data) + + def _report_continous_usage(self): + """Report usage every 10 minutes. + + This helps us to collect more data points for uptime of vLLM usages. + This function can also help send over performance metrics over time. + """ + while True: + time.sleep(600) + data = { + "uuid": self.uuid, + "log_time": _get_current_timestamp_ns(), + } + data.update(_GLOBAL_RUNTIME_DATA) + + self._write_to_file(data) + self._send_to_server(data) + + def _send_to_server(self, data: Dict[str, Any]) -> None: + try: + global_http_client = global_http_connection.get_sync_client() + global_http_client.post(_USAGE_STATS_SERVER, json=data) + except requests.exceptions.RequestException: + # silently ignore unless we are using debug log + logging.debug("Failed to send usage data to server") + + def _write_to_file(self, data: Dict[str, Any]) -> None: + os.makedirs(os.path.dirname(_USAGE_STATS_JSON_PATH), exist_ok=True) + Path(_USAGE_STATS_JSON_PATH).touch(exist_ok=True) + with open(_USAGE_STATS_JSON_PATH, "a") as f: + json.dump(data, f) + f.write("\n") + + +usage_message = UsageMessage()