|
|
""" |
|
|
This module provides the infrastructure for creating and managing compile packages
|
|
for torch.compile. We mainly have two abstractions here: |
|
|
- CompilePackage: Overarching data structure for storing and looking up a list of compiled codes.
|
|
- CodeCacheEntry: Data structure for a single code being compiled by torch.compile. |
|
|
The caching behavior is always explicitly under user control, so that a stronger guarantee can
|
|
be provided about cache hits for a specific compiled model. Users can load the compile package
|
|
from a different process or host. |
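
Example (a rough sketch of the intended flow, using only names defined in this
module; ``package`` and ``fn`` are assumed to come from the surrounding code):

    # Persist the dynamo state of a package under its source id.
    DynamoCache.save(package)

    # Later, possibly in a different process, reload and install it.
    package = DynamoCache.load_and_install_package(fn)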
|
|
""" |
|
|
|
|
|
import abc |
|
|
import ast |
|
|
import contextlib |
|
|
import dataclasses |
|
|
import functools |
|
|
import hashlib |
|
|
import importlib |
|
|
import inspect |
|
|
import logging |
|
|
import os |
|
|
import pickle |
|
|
import platform |
|
|
import shutil |
|
|
import sys |
|
|
import types |
|
|
from collections.abc import Generator, Iterator |
|
|
from typing import Any, Callable, NewType, Optional |
|
|
from typing_extensions import Never |
|
|
|
|
|
import torch |
|
|
import torch._inductor.package |
|
|
from torch._dynamo.exc import PackageError |
|
|
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext |
|
|
from torch._inductor.runtime.cache_dir_utils import cache_dir |
|
|
from torch.compiler._cache import CacheArtifactFactory |
|
|
|
|
|
from .bytecode_transformation import get_code_keys |
|
|
from .utils import dynamo_timed, increment_frame |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True) |
|
|
class SerializedCode: |
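    """
    A pickle-friendly mirror of ``types.CodeType``. ``from_code_object`` and
    ``to_code_object`` convert between the two, recursing into ``co_consts`` so
    that nested code objects are round-tripped as well.
    """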
|
|
co_argcount: int |
|
|
co_posonlyargcount: int |
|
|
co_kwonlyargcount: int |
|
|
co_nlocals: int |
|
|
co_stacksize: int |
|
|
co_flags: int |
|
|
co_code: bytes |
|
|
co_consts: tuple[Any, ...] |
|
|
co_names: tuple[str, ...] |
|
|
co_varnames: tuple[str, ...] |
|
|
co_filename: str |
|
|
co_name: str |
|
|
co_firstlineno: int |
|
|
co_cellvars: tuple[str, ...] |
|
|
co_freevars: tuple[str, ...] |
|
|
co_linetable: Optional[bytes] = None |
|
|
co_qualname: Optional[str] = None |
|
|
co_exceptiontable: Optional[bytes] = None |
|
|
    co_lnotab: Optional[bytes] = None
|
|
|
|
|
@classmethod |
|
|
@functools.cache |
|
|
def from_code_object(cls, code: types.CodeType) -> "SerializedCode": |
|
|
kwargs = {key: getattr(code, key) for key in get_code_keys()} |
|
|
kwargs["co_consts"] = tuple( |
|
|
cls.from_code_object(c) if isinstance(c, types.CodeType) else c |
|
|
for c in kwargs["co_consts"] |
|
|
) |
|
|
return cls(**kwargs) |
|
|
|
|
|
@classmethod |
|
|
@functools.cache |
|
|
def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType: |
|
|
kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()} |
|
|
kwargs["co_consts"] = tuple( |
|
|
cls.to_code_object(c) if isinstance(c, SerializedCode) else c |
|
|
for c in kwargs["co_consts"] |
|
|
) |
|
|
return types.CodeType( |
|
|
*kwargs.values(), |
|
|
) |
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class _GuardedCodeCacheEntry: |
|
|
""" |
|
|
Contains the serializable information associated with a single compilation in dynamo. |
|
|
To restore an execution of compiled code, we will need to serialize the following data: |
|
|
- Dynamo bytecode for mapping Python inputs/outputs. |
|
|
- Dynamo guards. |
|
|
""" |
|
|
|
|
|
guards_state: bytes |
|
|
dynamo_code: SerializedCode |
|
|
|
|
|
|
|
|
_BackendId = NewType("_BackendId", str) |
|
|
_FunctionId = NewType("_FunctionId", str) |
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True) |
|
|
class InlinedSource: |
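    """
    Source range and checksum of code inlined during tracing, recorded so that a
    loaded package can detect when the inlined source changed since compile time.
    """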
|
|
module: str |
|
|
firstlineno: int |
|
|
lastlineno: int |
|
|
checksum: str |
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class DynamoCaptureOutput: |
|
|
""" |
|
|
Core information generated from Dynamo for fullgraph=True. |
|
|
""" |
|
|
|
|
|
guarded_codes: list[_GuardedCodeCacheEntry] |
|
|
backend_ids: list[_BackendId] |
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class _DynamoCodeCacheEntry(DynamoCaptureOutput): |
|
|
""" |
|
|
Contains the serializable information associated with a single code object |
|
|
in dynamo. To restore an execution of compiled code, we will need the following |
|
|
ingredients: |
|
|
1. The "original" code object, which serves as the entry point for eager |
|
|
execution, i.e. the code only executed when there's no cache entry hit. |
|
|
2. The python module name this code object belongs to, for identifying the |
|
|
       enclosing global scope into which compiled and resume functions are injected.
|
|
    3. A list of function names that point to this code object. There could be
|
|
       multiple function objects pointing to the same code, such as recursive functions.
|
|
    4. A list of guarded codes that the eval frame dispatches to.
|
|
5. A list of imported module objects unioned from all compiled branches. |
|
|
    6. A list of "backends" (compiled fx graphs) unioned from all compiled branches.
|
|
    7. A string path used to access the original code object the user defined.
|
|
       A code object can be accessed via "{python_module}.{function_name}.{code_source}".
|
|
8. A boolean flag indicating whether the function is installed to global scope. |
|
|
9. A boolean flag indicating whether the function has a compile id. |
|
|
    10. Whether or not this code entry was bypassed.
|
|
""" |
|
|
|
|
|
python_code: SerializedCode |
|
|
python_module: str |
|
|
function_names: list[_FunctionId] |
|
|
import_sources: dict[str, str] |
|
|
code_source: Optional[str] |
|
|
install_to_global: bool |
|
|
has_compile_id: bool = False |
|
|
bypassed: bool = False |
|
|
|
|
|
|
|
|
def _lookup_code(entry: _DynamoCodeCacheEntry) -> types.CodeType: |
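    """
    Resolve the live code object referenced by ``entry``: look up the (single)
    function name inside ``entry.python_module``, then follow the attribute and
    index accesses encoded in ``entry.code_source`` down to the ``types.CodeType``.
    """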
|
|
assert len(entry.function_names) == 1 |
|
|
fn: Any = sys.modules[entry.python_module] |
|
|
parts = entry.function_names[0].split(".") |
|
|
for part in parts: |
|
|
fn = getattr(fn, part) |
|
|
if entry.code_source: |
|
|
parts = entry.code_source.split(".") |
|
|
for part in parts: |
|
|
if part.endswith("]"): |
|
|
index_begin = part.rfind("[") |
|
|
assert isinstance(index_begin, int) and index_begin >= 0 |
|
|
attr = getattr(fn, part[:index_begin], None) |
|
|
if attr is None: |
|
|
raise PackageError(f"Cannot find source for code entry {entry}") |
|
|
fn = attr[ast.literal_eval(part[index_begin + 1 : -1])] |
|
|
else: |
|
|
fn = getattr(fn, part) |
|
|
else: |
|
|
raise PackageError(f"Cannot find source for code entry {entry}") |
|
|
assert isinstance(fn, types.CodeType) |
|
|
return fn |
|
|
|
|
|
|
|
|
def _raise_resolution_error(code: types.CodeType, scope: Any) -> Never: |
|
|
raise PackageError( |
|
|
f"Cannot resolve a fully qualified name for {code}. Lookup scope: {scope}" |
|
|
) |
|
|
|
|
|
|
|
|
def _get_code_source(code: types.CodeType) -> tuple[str, str]: |
|
|
""" |
|
|
Given a code object, return a fully qualified name which will be used as |
|
|
a serialized handle to access the code object from the new process. |
|
|
This is normally a straightforward process, but there are some corner cases: |
|
|
    1. When a function is defined with a decorator, the function will be captured
|
|
       inside a closure by the wrapper object.
|
|
    2. When a function is defined as a nested function, its code object will be
|
|
       stored in the co_consts field of the parent code object by the Python compiler.
|
|
This function handles all of the corner cases above. |
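
    Illustration of the nested-function case (the exact ``co_consts`` index
    depends on the compiled code)::

        def outer():
            def inner():
                pass

        # For inner's code object, this function returns something like
        # ("outer", "__code__.co_consts[<i>]").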
|
|
""" |
|
|
|
|
|
module = inspect.getmodule(code) |
|
|
if module is None: |
|
|
raise PackageError(f"Cannot find module for code {code}") |
|
|
|
|
|
toplevel: Any = module |
|
|
if sys.version_info >= (3, 11): |
|
|
parts = code.co_qualname.split(".") |
|
|
|
|
|
for part in parts: |
|
|
if not hasattr(toplevel, part): |
|
|
_raise_resolution_error(code, toplevel) |
|
|
toplevel = getattr(toplevel, part) |
|
|
if inspect.isfunction(toplevel): |
|
|
break |
|
|
seen = set() |
|
|
|
|
|
def _find_code_source(obj: Any) -> Optional[str]: |
|
|
nonlocal toplevel |
|
|
nonlocal seen |
|
|
if obj in seen: |
|
|
return None |
|
|
|
|
|
seen.add(obj) |
|
|
|
|
|
if inspect.iscode(obj): |
|
|
if obj is code: |
|
|
return "" |
|
|
|
|
|
for i, const in enumerate(obj.co_consts): |
|
|
if (res := _find_code_source(const)) is not None: |
|
|
return f".co_consts[{i}]{res}" |
|
|
|
|
|
if inspect.isfunction(obj): |
|
|
if (res := _find_code_source(obj.__code__)) is not None: |
|
|
toplevel = obj |
|
|
return f".__code__{res}" |
|
|
if obj.__closure__ is not None: |
|
|
for i, cell in enumerate(obj.__closure__): |
|
|
try: |
|
|
cell_contents = cell.cell_contents |
|
|
except ValueError: |
|
|
continue |
|
|
if not ( |
|
|
inspect.isfunction(cell_contents) |
|
|
or inspect.iscode(cell_contents) |
|
|
): |
|
|
continue |
|
|
if (res := _find_code_source(cell_contents)) is not None: |
|
|
toplevel = obj |
|
|
return f".__closure__[{i}].cell_contents{res}" |
|
|
|
|
|
if sys.version_info < (3, 11): |
|
|
if inspect.ismodule(obj): |
|
|
for value in obj.__dict__.values(): |
|
|
if not (inspect.isfunction(value) or inspect.isclass(value)): |
|
|
continue |
|
|
if (res := _find_code_source(value)) is not None: |
|
|
return res |
|
|
|
|
|
if inspect.isclass(obj): |
|
|
for name, value in obj.__dict__.items(): |
|
|
value = getattr(obj, name) |
|
|
if not (inspect.isfunction(value) or inspect.isclass(value)): |
|
|
continue |
|
|
if (res := _find_code_source(value)) is not None: |
|
|
if value.__name__ != name: |
|
|
_raise_resolution_error(code, toplevel) |
|
|
return res |
|
|
return None |
|
|
|
|
|
code_source = _find_code_source(toplevel) |
|
|
if code_source is None: |
|
|
_raise_resolution_error(code, toplevel) |
|
|
return toplevel.__qualname__, code_source.strip(".") |
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class _DynamoCacheEntry: |
|
|
codes: list[_DynamoCodeCacheEntry] |
|
|
inlined_sources: set[InlinedSource] |
|
|
python_version: str = platform.python_version() |
|
|
torch_version: str = torch.__version__ |
|
|
|
|
|
@property |
|
|
def backend_ids(self) -> set[_BackendId]: |
|
|
return {backend_id for code in self.codes for backend_id in code.backend_ids} |
|
|
|
|
|
|
|
|
@CacheArtifactFactory.register |
|
|
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]): |
|
|
@staticmethod |
|
|
def type() -> str: |
|
|
return "precompile_dynamo" |
|
|
|
|
|
def after_deserialization(self) -> _DynamoCacheEntry: |
|
|
return pickle.loads(self.content) |
|
|
|
|
|
|
|
|
def _hash_source(source: str) -> str: |
|
|
sha256_hash = hashlib.sha256() |
|
|
sha256_hash.update(source.encode()) |
|
|
return sha256_hash.hexdigest() |
|
|
|
|
|
|
|
|
def _get_sourcelines( |
|
|
m: types.ModuleType, firstlineno: int, lastlineno: int |
|
|
) -> list[str]: |
|
|
return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1] |
|
|
|
|
|
|
|
|
def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str: |
|
|
return _hash_source("".join(_get_sourcelines(m, firstlineno, lastlineno))) |
|
|
|
|
|
|
|
|
def _compile_frame_context( |
|
|
code: types.CodeType, |
|
|
) -> contextlib.AbstractContextManager[None]: |
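    """
    Return a context manager that mimics the bookkeeping of a fresh dynamo
    compilation while a precompiled entry is being installed: it bumps the frame
    counter, allocates a new compile id, and wraps the work in ``dynamo_timed``
    so the install shows up in the usual compilation metrics.
    """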
|
|
from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start |
|
|
from torch._guards import compile_context, CompileContext |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
def _ctx() -> Iterator[None]: |
|
|
increment_frame() |
|
|
compile_id = get_compile_id(frame_state={}) |
|
|
with ( |
|
|
compile_context(CompileContext(compile_id)), |
|
|
dynamo_timed( |
|
|
"_compile.compile_inner", |
|
|
phase_name="entire_frame_compile", |
|
|
dynamo_compile_column_us="dynamo_cumulative_compile_time_us", |
|
|
|
|
|
metadata={ |
|
|
"frame_key": str(torch._dynamo.utils.curr_frame), |
|
|
"co_name": code.co_name, |
|
|
"co_filename": code.co_filename, |
|
|
"co_firstlineno": code.co_firstlineno, |
|
|
}, |
|
|
), |
|
|
): |
|
|
log_dynamo_start(code) |
|
|
yield |
|
|
|
|
|
return _ctx() |
|
|
|
|
|
|
|
|
class CompilePackage: |
|
|
""" |
|
|
CompilePackage is considered a low level component and should not be directly exposed to |
|
|
end users. It has the following interface: |
|
|
|
|
|
1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. |
|
|
a. when `dynamo` argument is None, it will construct a brand new CompilePackage object. |
|
|
b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. |
|
|
    2. `package.cache_entry()` which dumps the dynamo states to a _DynamoCacheEntry object.
|
|
    3. `package.install(backends)` which handles all the side-effectful global scope
|
|
updates with compiled functions and resume functions. |
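
    Example (sketch only; ``fn``, ``entry`` and ``backends`` are assumed to be
    provided by the surrounding caching machinery)::

        package = CompilePackage(fn)                # 1a. brand new package
        ...                                         # compile, recording entries
        entry = package.cache_entry()               # 2.  serializable dynamo state
        package = CompilePackage(fn, dynamo=entry)  # 1b. reload, e.g. in a new process
        package.install(backends)                   # 3.  re-install compiled functions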
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
fn: Optional[Callable[..., Any]], |
|
|
dynamo: Optional[_DynamoCacheEntry] = None, |
|
|
ignore_inlined_sources: bool = False, |
|
|
) -> None: |
|
|
self._innermost_fn = None |
|
|
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} |
|
|
|
|
|
self._current_entry: Optional[_DynamoCodeCacheEntry] = None |
|
|
self._installed_globals: dict[types.ModuleType, list[str]] = {} |
|
|
|
|
|
|
|
|
self._cached_backends: dict[_BackendId, Any] = {} |
|
|
self._inlined_sources: set[InlinedSource] = set() |
|
|
self._resume_codes: set[types.CodeType] = set() |
|
|
self._initialized = False |
|
|
if fn is not None: |
|
|
self.initialize(fn, dynamo, ignore_inlined_sources) |
|
|
self.uninstall() |
|
|
self.validate() |
|
|
|
|
|
def is_initialized(self) -> bool: |
|
|
return self._initialized |
|
|
|
|
|
def initialize( |
|
|
self, |
|
|
fn: Any, |
|
|
dynamo: Optional[_DynamoCacheEntry] = None, |
|
|
ignore_inlined_sources: bool = False, |
|
|
) -> None: |
|
|
from .eval_frame import innermost_fn |
|
|
|
|
|
assert not self._initialized |
|
|
self._inlined_sources = set() |
|
|
self._innermost_fn = innermost_fn(fn) |
|
|
assert self._innermost_fn is not None |
|
|
if dynamo is not None: |
|
|
assert isinstance(dynamo, _DynamoCacheEntry) |
|
|
if dynamo.python_version != platform.python_version(): |
|
|
raise RuntimeError( |
|
|
f"Compile package was created with a different Python version: {dynamo.python_version}" |
|
|
) |
|
|
if dynamo.torch_version != torch.__version__: |
|
|
raise RuntimeError( |
|
|
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" |
|
|
) |
|
|
if not ignore_inlined_sources: |
|
|
for code in dynamo.inlined_sources: |
|
|
m = importlib.import_module(code.module) |
|
|
checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno) |
|
|
if checksum != code.checksum: |
|
|
raise RuntimeError( |
|
|
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})" |
|
|
) |
|
|
|
|
|
self._inlined_sources = dynamo.inlined_sources |
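
            # The first serialized entry always corresponds to the innermost
            # compiled function itself; the remaining entries are keyed by their
            # deserialized code objects.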
|
|
|
|
|
main, *codes = dynamo.codes |
|
|
self._codes = {self._innermost_fn.__code__: main} |
|
|
for code in codes: |
|
|
self._codes[SerializedCode.to_code_object(code.python_code)] = code |
|
|
else: |
|
|
self._add_function( |
|
|
self._innermost_fn.__code__, self._innermost_fn.__module__ |
|
|
) |
|
|
self._initialized = True |
|
|
|
|
|
def _add_function( |
|
|
self, |
|
|
python_code: types.CodeType, |
|
|
python_module: str, |
|
|
function_name: Optional[_FunctionId] = None, |
|
|
code_source: Optional[str] = None, |
|
|
install_to_global: bool = False, |
|
|
) -> None: |
|
|
if python_code not in self._codes: |
|
|
code = _DynamoCodeCacheEntry( |
|
|
python_code=SerializedCode.from_code_object(python_code), |
|
|
python_module=python_module, |
|
|
function_names=[], |
|
|
guarded_codes=[], |
|
|
import_sources={}, |
|
|
backend_ids=[], |
|
|
code_source=code_source, |
|
|
install_to_global=install_to_global, |
|
|
) |
|
|
self._codes[python_code] = code |
|
|
else: |
|
|
code = self._codes[python_code] |
|
|
assert code.python_module == python_module |
|
|
assert code.install_to_global == install_to_global |
|
|
assert code.code_source == code_source |
|
|
|
|
|
if function_name is not None: |
|
|
code.function_names.append(function_name) |
|
|
|
|
|
@property |
|
|
def cached_backends(self) -> dict[_BackendId, Any]: |
|
|
return self._cached_backends |
|
|
|
|
|
@functools.cached_property |
|
|
def source_id(self) -> str: |
|
|
assert self._innermost_fn is not None |
|
|
return CompilePackage.source_id_from_fn(self._innermost_fn) |
|
|
|
|
|
def _add_user_function(self, code: types.CodeType) -> None: |
|
|
function_name, code_source = _get_code_source(code) |
|
|
module = inspect.getmodule(code) |
|
|
if module is None: |
|
|
raise PackageError(f"Cannot find module for code {code}") |
|
|
self._add_function( |
|
|
code, |
|
|
module.__name__, |
|
|
function_name=_FunctionId(function_name), |
|
|
code_source=code_source, |
|
|
) |
|
|
|
|
|
@contextlib.contextmanager |
|
|
def code_context(self, code: types.CodeType) -> Generator[None, None, None]: |
|
|
assert self._current_entry is None |
|
|
|
|
|
|
|
|
|
|
|
if code not in self._codes: |
|
|
self._add_user_function(code) |
|
|
|
|
|
entry = self._codes[code] |
|
|
self._current_entry = entry |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
            if entry.bypassed:
|
|
del self._codes[code] |
|
|
entry.has_compile_id = True |
|
|
self._current_entry = None |
|
|
|
|
|
def add_guarded_code( |
|
|
self, |
|
|
guards_state: bytes, |
|
|
dynamo_code: types.CodeType, |
|
|
) -> None: |
|
|
assert self._current_entry is not None |
|
|
if self._current_entry.bypassed: |
|
|
return |
|
|
guarded_code_entry = _GuardedCodeCacheEntry( |
|
|
guards_state=guards_state, |
|
|
dynamo_code=SerializedCode.from_code_object(dynamo_code), |
|
|
) |
|
|
self._current_entry.guarded_codes.append(guarded_code_entry) |
|
|
|
|
|
def add_inlined_source(self, sources: list[types.CodeType]) -> None: |
|
|
assert self._current_entry is not None |
|
|
if self._current_entry.bypassed: |
|
|
return |
|
|
for code in sources: |
|
|
if code in self._resume_codes: |
|
|
continue |
|
|
module = inspect.getmodule(code) |
|
|
if module is None: |
|
|
continue |
|
|
sourcelines, firstlineno = inspect.getsourcelines(code) |
|
|
lastlineno = firstlineno + len(sourcelines) |
|
|
source = "".join(sourcelines) |
|
|
assert source == "".join(_get_sourcelines(module, firstlineno, lastlineno)) |
|
|
self._inlined_sources.add( |
|
|
InlinedSource( |
|
|
module=module.__name__, |
|
|
firstlineno=firstlineno, |
|
|
lastlineno=lastlineno, |
|
|
checksum=_hash_source(source), |
|
|
) |
|
|
) |
|
|
|
|
|
def bypass_current_entry(self) -> None: |
|
|
assert self._current_entry is not None |
|
|
self._current_entry.bypassed = True |
|
|
|
|
|
def add_resume_function( |
|
|
self, |
|
|
python_code: types.CodeType, |
|
|
python_module: str, |
|
|
function_name: Optional[str], |
|
|
) -> None: |
|
|
self._add_function( |
|
|
python_code, |
|
|
python_module, |
|
|
function_name=_FunctionId(function_name) if function_name else None, |
|
|
install_to_global=True, |
|
|
) |
|
|
self._resume_codes.add(python_code) |
|
|
|
|
|
def add_import_source(self, alias: str, module_name: str) -> None: |
|
|
assert self._current_entry is not None |
|
|
self._current_entry.import_sources[alias] = module_name |
|
|
|
|
|
def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None: |
|
|
assert self._current_entry is not None |
|
|
assert backend_id.startswith("__compiled_fn_") |
|
|
backend_id = _BackendId(backend_id) |
|
|
self._current_entry.backend_ids.append(backend_id) |
|
|
if backend is not None: |
|
|
self._cached_backends[backend_id] = backend |
|
|
|
|
|
def validate(self) -> None: |
|
|
assert self._current_entry is None |
|
|
assert self._innermost_fn is not None |
|
|
assert self._initialized |
|
|
assert next(iter(self._codes)) is self._innermost_fn.__code__ |
|
|
|
|
|
def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: |
|
|
module.__dict__[name] = value |
|
|
self._installed_globals.setdefault(module, []).append(name) |
|
|
|
|
|
def uninstall(self) -> None: |
|
|
from torch._C._dynamo.eval_frame import _reset_precompile_entries |
|
|
|
|
|
assert self._innermost_fn is not None |
|
|
for module, names in self._installed_globals.items(): |
|
|
for name in names: |
|
|
module.__dict__.pop(name) |
|
|
|
|
|
self._installed_globals = {} |
|
|
|
|
|
_reset_precompile_entries(self._innermost_fn.__code__) |
|
|
|
|
|
def install(self, backends: dict[_BackendId, Any]) -> None: |
|
|
""" |
|
|
Sync the package states to the compiled function. This includes the following actions: |
|
|
1. Clean up the previously installed states. |
|
|
2. Install the compiled functions to global scopes. |
|
|
3. Install the precompiled cache entries to ExtraStates on the code object. |
|
|
""" |
|
|
from torch._C._dynamo.eval_frame import _load_precompile_entry |
|
|
|
|
|
from .output_graph import get_builtins_dict |
|
|
|
|
|
self.uninstall() |
|
|
for code, entry in self._codes.items(): |
|
|
context = ( |
|
|
_compile_frame_context(code) |
|
|
if entry.has_compile_id |
|
|
else contextlib.nullcontext() |
|
|
) |
|
|
with context: |
|
|
module = sys.modules[entry.python_module] |
|
|
for alias, module_name in entry.import_sources.items(): |
|
|
self._install_global( |
|
|
module, alias, importlib.import_module(module_name) |
|
|
) |
|
|
target_code = code |
|
|
if entry.install_to_global: |
|
|
for function_name in entry.function_names: |
|
|
fn = types.FunctionType(code, module.__dict__, function_name) |
|
|
self._install_global(module, function_name, fn) |
|
|
if entry.code_source: |
|
|
target_code = _lookup_code(entry) |
|
|
|
|
|
for backend_id in entry.backend_ids: |
|
|
if backend_id not in backends: |
|
|
raise RuntimeError( |
|
|
f"Backend {backend_id} is not found in the given backends" |
|
|
) |
|
|
with dynamo_timed( |
|
|
"after_deserialization", phase_name="backend_compile" |
|
|
): |
|
|
backend = backends[backend_id].after_deserialization() |
|
|
self._install_global( |
|
|
module, |
|
|
backend_id, |
|
|
torch._dynamo.disable(backend), |
|
|
) |
|
|
|
|
|
if len(entry.guarded_codes) == 0: |
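                    # Nothing was guarded for this entry, so tell the eval frame
                    # to skip this code object entirely rather than dispatch to a
                    # precompiled entry.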
|
|
|
|
|
|
|
|
torch._dynamo.eval_frame.skip_code(target_code) |
|
|
|
|
|
for guarded_code in entry.guarded_codes: |
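                    # Rebuild the guard check function from the serialized guards
                    # state and register it, together with the precompiled dynamo
                    # bytecode, with the eval frame.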
|
|
guards_state = pickle.loads(guarded_code.guards_state) |
|
|
runtime_global_scope = sys.modules[entry.python_module].__dict__ |
|
|
|
|
|
|
|
|
if ( |
|
|
builtin_dict_name |
|
|
:= guards_state.output_graph.name_of_builtins_dict_key_in_fglobals |
|
|
): |
|
|
builtins_dict = get_builtins_dict(runtime_global_scope) |
|
|
if builtin_dict_name in runtime_global_scope: |
|
|
assert ( |
|
|
runtime_global_scope[builtin_dict_name] is builtins_dict |
|
|
) |
|
|
else: |
|
|
runtime_global_scope[builtin_dict_name] = builtins_dict |
|
|
assert isinstance(guards_state, torch._dynamo.guards.GuardsState) |
|
|
check_fn_manager = torch._dynamo.guards.CheckFunctionManager( |
|
|
target_code, |
|
|
guards_state.output_graph, |
|
|
shape_code_parts=guards_state.shape_code_parts, |
|
|
runtime_global_scope=runtime_global_scope, |
|
|
) |
|
|
_load_precompile_entry( |
|
|
target_code, |
|
|
check_fn_manager.guard_manager, |
|
|
SerializedCode.to_code_object(guarded_code.dynamo_code), |
|
|
) |
|
|
|
|
|
def cache_entry(self) -> _DynamoCacheEntry: |
|
|
self.validate() |
|
|
return _DynamoCacheEntry( |
|
|
codes=list(self._codes.values()), inlined_sources=self._inlined_sources |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def source_id_from_fn(fn: Callable[..., Any]) -> str: |
|
|
from .eval_frame import innermost_fn |
|
|
|
|
|
innermost_fn_ = innermost_fn(fn) |
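
        # The source id is derived from the innermost function's qualified name
        # and first line number, so the same source function maps to the same
        # key across processes.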
|
|
|
|
|
sha256_hash = hashlib.sha256() |
|
|
sha256_hash.update(innermost_fn_.__qualname__.encode()) |
|
|
sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode()) |
|
|
return sha256_hash.hexdigest() |
|
|
|
|
|
|
|
|
@CacheArtifactFactory.register |
|
|
class EagerCacheArtifact(PrecompileCacheArtifact[Any]): |
|
|
@staticmethod |
|
|
def type() -> str: |
|
|
return "precompile_eager" |
|
|
|
|
|
def after_deserialization(self) -> Any: |
|
|
return pickle.loads(self.content) |
|
|
|
|
|
|
|
|
_Backends = dict[_BackendId, PrecompileCacheArtifact[Any]] |
|
|
|
|
|
|
|
|
class DynamoStore(abc.ABC): |
|
|
""" |
|
|
A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them. |
|
|
|
|
|
This is an abstract base class for different storage implementations. |
|
|
""" |
|
|
|
|
|
def record_package(self, package: CompilePackage) -> None: |
|
|
""" |
|
|
Records a package to PrecompileContext, so that it can be serialized later. |
|
|
""" |
|
|
cache_entry = package.cache_entry() |
|
|
pickled_result = pickle.dumps(cache_entry) |
|
|
PrecompileContext.record_artifact( |
|
|
_DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result |
|
|
) |
|
|
|
|
|
def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None: |
|
|
""" |
|
|
Records eager fx graphs to PrecompileContext for testing purposes. |
|
|
""" |
|
|
pickled_result = pickle.dumps(backend) |
|
|
PrecompileContext.record_artifact( |
|
|
EagerCacheArtifact.type(), key=backend_id, content=pickled_result |
|
|
) |
|
|
|
|
|
@abc.abstractmethod |
|
|
def clear(self) -> None: ... |
|
|
|
|
|
@abc.abstractmethod |
|
|
def write( |
|
|
self, |
|
|
dynamo: _DynamoCacheEntry, |
|
|
backends: _Backends, |
|
|
path: str, |
|
|
) -> None: |
|
|
""" |
|
|
Abstract method to write dynamo cache entry and backends to storage. |
|
|
|
|
|
Args: |
|
|
dynamo: The dynamo cache entry to write |
|
|
backends: Dictionary of backend content to write |
|
|
path: Path or key to identify where to write the data |
|
|
""" |
|
|
... |
|
|
|
|
|
def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None: |
|
|
""" |
|
|
        Saves a cache entry under the given key. Grabs its backends from PrecompileContext.
|
|
""" |
|
|
backend_content: _Backends = {} |
|
|
for backend_id in cache_entry.backend_ids: |
|
|
serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id) |
|
|
if serialized_backend is None: |
|
|
raise RuntimeError( |
|
|
f"Backend {backend_id} is not found in the given backends" |
|
|
) |
|
|
assert isinstance(serialized_backend, PrecompileCacheArtifact) |
|
|
backend_content[backend_id] = serialized_backend |
|
|
|
|
|
self.write(cache_entry, backend_content, key) |
|
|
|
|
|
def save_package(self, package: CompilePackage, key: str) -> None: |
|
|
""" |
|
|
        Saves a package under the given key. Grabs backends from PrecompileContext.
|
|
""" |
|
|
self.record_package(package) |
|
|
cache_entry = package.cache_entry() |
|
|
self.save_cache_entry(cache_entry, key) |
|
|
|
|
|
@abc.abstractmethod |
|
|
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: |
|
|
""" |
|
|
Abstract method to read dynamo cache entry and backends from storage. |
|
|
|
|
|
Args: |
|
|
path: Path or key to identify where to read the data from |
|
|
|
|
|
Returns: |
|
|
A tuple containing (dynamo_cache_entry, backend_content) |
|
|
""" |
|
|
... |
|
|
|
|
|
def load_cache_entry( |
|
|
self, key: str |
|
|
) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]: |
|
|
cache_entry, backend_content = self.read(key) |
|
|
for backend_id, backend in backend_content.items(): |
|
|
PrecompileContext.record_artifact( |
|
|
backend.type(), key=backend.key, content=backend.content |
|
|
) |
|
|
backend_content[backend_id] = backend |
|
|
|
|
|
return cache_entry, backend_content |
|
|
|
|
|
def load_package( |
|
|
self, fn: Any, key: str |
|
|
) -> tuple[CompilePackage, dict[_BackendId, Any]]: |
|
|
""" |
|
|
        Loads a package for the given key and returns it plus a dictionary of its backends.
|
|
""" |
|
|
cache_entry, backend_content = self.load_cache_entry(key) |
|
|
package = CompilePackage(fn, cache_entry) |
|
|
return package, backend_content |
|
|
|
|
|
|
|
|
class InMemoryDynamoStore(DynamoStore): |
|
|
""" |
|
|
A DynamoStore implementation that keeps state about CompilePackages in memory. |
|
|
""" |
|
|
|
|
|
def __init__(self) -> None: |
|
|
self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {} |
|
|
|
|
|
def clear(self) -> None: |
|
|
self.packages.clear() |
|
|
|
|
|
def write( |
|
|
self, |
|
|
dynamo: _DynamoCacheEntry, |
|
|
backends: _Backends, |
|
|
path: str, |
|
|
) -> None: |
|
|
""" |
|
|
Store the dynamo cache entry and backends in memory instead of writing to disk. |
|
|
""" |
|
|
self.packages[path] = (dynamo, backends) |
|
|
|
|
|
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: |
|
|
""" |
|
|
Read dynamo cache entry and backends from memory. |
|
|
""" |
|
|
if path not in self.packages: |
|
|
raise RuntimeError(f"No package found with key {path}") |
|
|
|
|
|
return self.packages[path] |
|
|
|
|
|
|
|
|
class DiskDynamoStore(DynamoStore): |
|
|
""" |
|
|
A DynamoStore implementation that keeps state about CompilePackages on disk. |
|
|
""" |
|
|
|
|
|
def __init__(self, path_prefix: str = ""): |
|
|
""" |
|
|
Initialize a DiskDynamoStore with a path prefix. |
|
|
|
|
|
Args: |
|
|
path_prefix: Prefix directory for where to put CompilePackages on disk |
|
|
""" |
|
|
self.path_prefix = path_prefix |
|
|
|
|
|
def clear(self) -> None: |
|
|
""" |
|
|
Clear all CompilePackages from disk. |
|
|
""" |
|
|
if self.path_prefix: |
|
|
shutil.rmtree(self.path_prefix, ignore_errors=True) |
|
|
|
|
|
def write( |
|
|
self, |
|
|
dynamo: _DynamoCacheEntry, |
|
|
backends: _Backends, |
|
|
path: str, |
|
|
) -> None: |
|
|
""" |
|
|
Write dynamo cache entry and backends to disk. |
|
|
""" |
|
|
path = os.path.join(self.path_prefix, path) if self.path_prefix else path |
|
|
try: |
|
|
os.makedirs(path, exist_ok=True) |
|
|
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path: |
|
|
pickle.dump(dynamo, dynamo_path) |
|
|
with open(os.path.join(path, "backends"), "wb") as backend_path: |
|
|
pickle.dump(backends, backend_path) |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Failed to save package to {path}: {e}") from e |
|
|
|
|
|
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: |
|
|
""" |
|
|
Read dynamo cache entry and backends from disk. |
|
|
""" |
|
|
path = os.path.join(self.path_prefix, path) if self.path_prefix else path |
|
|
try: |
|
|
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path: |
|
|
cache_entry = pickle.load(dynamo_path) |
|
|
with open(os.path.join(path, "backends"), "rb") as backend_path: |
|
|
backend_content = pickle.load(backend_path) |
|
|
return cache_entry, backend_content |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Failed to load package from path {path}: {e}") from e |
|
|
|
|
|
|
|
|
class DiskDynamoCache(DiskDynamoStore): |
|
|
""" |
|
|
Special DiskDynamoStore which adds some helper functions for automatically |
|
|
    tracking the paths of packages.
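
    Example (sketch; ``fn`` is a function that was previously compiled and saved)::

        package = DynamoCache.load_and_install_package(fn)
        if package is None:
            ...  # nothing cached yet; compile and call DynamoCache.save(package)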
|
|
""" |
|
|
|
|
|
def save(self, package: CompilePackage) -> None: |
|
|
""" |
|
|
        Saves a package under its source id key. Grabs backends from PrecompileContext.
|
|
""" |
|
|
key = package.source_id |
|
|
logger.info("Saving CompilePackage for %s", package.source_id) |
|
|
super().save_package(package, key) |
|
|
|
|
|
def load( |
|
|
self, fn: Callable[..., Any] |
|
|
) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]: |
|
|
""" |
|
|
        Loads the cache entry for ``fn``, if one exists, plus a dictionary of its backends.
|
|
""" |
|
|
key = CompilePackage.source_id_from_fn(fn) |
|
|
logger.info("Loading CompilePackage for %s", key) |
|
|
path = os.path.join(self.path_prefix, key) |
|
|
if os.path.exists(path): |
|
|
try: |
|
|
result = super().load_cache_entry(key) |
|
|
return result |
|
|
except Exception as e: |
|
|
logger.warning("Failed to load package from path %s: %s", path, str(e)) |
|
|
return None |
|
|
logger.info("No package found for %s", key) |
|
|
return None |
|
|
|
|
|
def load_and_install_package( |
|
|
self, fn: Callable[..., Any] |
|
|
) -> Optional[CompilePackage]: |
|
|
""" |
|
|
        Loads a package directly and installs its backends.
|
|
""" |
|
|
results = self.load(fn) |
|
|
if results is None: |
|
|
return None |
|
|
else: |
|
|
(entry, backends) = results |
|
|
package = CompilePackage(fn, entry) |
|
|
package.install(backends) |
|
|
return package |
|
|
|
|
|
|
|
|
DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo")) |
|
|
|