| | import os |
| | import sys |
| | from abc import abstractmethod |
| | from contextlib import contextmanager |
| | from types import CodeType |
| | from typing import Callable, List |
| |
|
| | import torch |
| |
|
| |
|
class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.

    Subclasses should:
        1. Implement the forward method
        2. Implement the dispatch logic in the __call__ method
           It can use `self.compiled_codes` to access the compiled bytecode,
           and `with self.dispatch_to_code(index):` to dispatch to
           the compiled code.
        3. Implement the `__init__` method to determine how to call
           `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable, use_custom_dispatcher: bool = True):
        """Store the compiled callable and register the bytecode hook.

        Args:
            compiled_callable: The callable invoked by the default
                `__call__` implementation.
            use_custom_dispatcher: Flag stored on the instance as
                `self.use_custom_dispatcher`; it is not consulted by this
                base class itself.
        """
        self.compiled_callable = compiled_callable
        # Code object of the class-level `forward`, captured before any swap
        # so it can be recognised in the hook and restored after dispatch.
        self.original_code_object = self.__class__.forward.__code__
        # Filled in by `bytecode_hook`, one entry per accepted compilation.
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        self.use_custom_dispatcher: bool = use_custom_dispatcher

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.

        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.

        The default implementation simply forwards all arguments to
        `self.compiled_callable` and returns its result.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Computation to be compiled; must be provided by subclasses."""
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        `new_code` is appended to `self.compiled_codes` only when
        `old_code` is this wrapper's original `forward` code object and the
        frame being compiled belongs to this very instance.

        Args:
            old_code: The code object that was compiled.
            new_code: The code object produced for `old_code`.
        """
        if old_code is not self.original_code_object:
            return
        # Walk up the call stack until we reach the `_compile` function
        # defined in a file named `convert_frame.py`.
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        # That function's local variable `frame` is the frame being compiled.
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # The code object is shared by every instance of the class, so only
        # record compilations triggered by a call on this instance.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.

        The swap is performed on the `forward` function looked up through
        the class, and the original code object is restored when the
        `with` block exits, including when it exits with an exception.

        Args:
            index: Position in `self.compiled_codes` of the bytecode to use.

        Raises:
            IndexError: If `index` is out of range for
                `self.compiled_codes`; nothing is modified in that case.
        """
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always restore, otherwise an exception inside the `with` body
            # would leave `forward` permanently bound to the compiled code.
            self.__class__.forward.__code__ = self.original_code_object
| |
|