|
|
""" |
|
|
Profile Guided Optimization (PGO) implementation for Dynamo. |
|
|
|
|
|
This module provides functionality for caching and managing code state profiles |
|
|
that guide optimization decisions in Dynamo. It implements both local and remote |
|
|
caching mechanisms for storing profile information across runs, handles profile |
|
|
merging across distributed ranks, and manages the lifecycle of profile data |
|
|
during compilation. The profiles track dynamic vs static properties of tensors |
|
|
and help Dynamo make better specialization decisions. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import base64 |
|
|
import copy |
|
|
import dataclasses |
|
|
import enum |
|
|
import functools |
|
|
import logging |
|
|
import os |
|
|
import pickle |
|
|
import re |
|
|
import zlib |
|
|
from collections import defaultdict |
|
|
from typing import Optional, TYPE_CHECKING, TypeVar, Union |
|
|
from typing_extensions import override, Self |
|
|
|
|
|
import torch._dynamo.config |
|
|
import torch._utils_internal |
|
|
import torch.compiler.config |
|
|
import torch.distributed as dist |
|
|
from torch._dynamo.utils import ( |
|
|
CompileEventLogger, |
|
|
dynamo_timed, |
|
|
set_feature_use, |
|
|
warn_once, |
|
|
) |
|
|
from torch._environment import is_fbcode |
|
|
from torch._logging._internal import trace_structured_artifact |
|
|
from torch.compiler._cache import ( |
|
|
CacheArtifact, |
|
|
CacheArtifactFactory, |
|
|
CacheArtifactManager, |
|
|
) |
|
|
from torch.utils._ordered_set import OrderedSet |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
import types |
|
|
|
|
|
from torch._dynamo.symbolic_convert import InstructionTranslator |
|
|
from torch._inductor.remote_cache import JsonDataTy, RemoteCache |
|
|
|
|
|
|
|
|
class ReservedWorkflowIdUserError(ValueError):
    """Raised when a user-supplied job id collides with a reserved prefix."""
|
|
|
|
|
|
|
|
# Module-level logger for PGO cache/profile events.
log = logging.getLogger(__name__)

# Seconds to wait for the local code-state file lock before timing out
# (used by write_local_impl when persisting the profile).
LOCK_TIMEOUT = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.cache |
|
|
def _hash_containing_file(filepath: str) -> str: |
|
|
|
|
|
if not os.path.exists(filepath): |
|
|
return filepath |
|
|
|
|
|
with open(filepath, "rb") as file: |
|
|
content = file.read() |
|
|
crc32_value = zlib.crc32(content) |
|
|
hash = format(crc32_value & 0xFFFFFFFF, "08x") |
|
|
return hash |
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class CodeId:
    """Stable identity for a code object across processes and checkouts.

    Equality and hashing deliberately ignore ``filename`` and rely on the
    containing file's content hash instead, so a profile keyed by CodeId
    survives the same file living at different paths.  ``filename`` is kept
    only for human-readable rendering.
    """

    filename: str
    firstlineno: int
    name: str
    file_hash: str

    def __eq__(self, other: object) -> bool:
        # Compare content hash + position + name; filename is display-only.
        if not isinstance(other, CodeId):
            return False
        return (self.file_hash, self.firstlineno, self.name) == (
            other.file_hash,
            other.firstlineno,
            other.name,
        )

    def __hash__(self) -> int:
        # Must stay consistent with __eq__: filename is excluded.
        return hash((self.file_hash, self.name, self.firstlineno))

    def __str__(self) -> str:
        return f"hash({self.file_hash}){self.filename}:{self.firstlineno}:{self.name}"

    @staticmethod
    def make(code: types.CodeType) -> CodeId:
        """Build a CodeId for a code object, hashing its containing file."""
        return CodeId(
            code.co_filename,
            code.co_firstlineno,
            code.co_name,
            _hash_containing_file(code.co_filename),
        )
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class CodeState:
    """Per-code-object profile: maps an input source name to the
    dynamic/static facts (FrameStateSizeEntry) observed for it."""

    # defaultdict so first-time lookups yield a fresh, fully-unset entry.
    automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
        default_factory=lambda: defaultdict(FrameStateSizeEntry)
    )
|
|
|
|
|
|
|
|
# Deep-copied snapshot of _CODE_STATE taken right after a cache hit (see hit());
# put_code_state compares against it to skip writes when nothing changed.
_INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# Process-wide PGO profile, lazily populated by get_code_state().
_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
# Ensures the MLHub dynamic-shapes insight is emitted at most once per process
# (see log_frame_dynamic_whitelist).
_LOGGED_DYNAMIC_ALLOWLIST: bool = False
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class InferStride:
    """
    Denotes the quantity stride[dim] * size[dim], which is what the stride would
    be for the next physical dimension that results in a contiguous layout.

    For example, given size = [2, 3], stride = [3, 1], we can replace this with
    stride = [InferStride(1), 1], because InferStride(1) = stride[1] * size[1] = 1 * 3 = 3

    Indirecting the representation in this way is important for the join operation
    on strides as if we join [2, 3][3, 1] and [2, 4][4, 1],
    we don't want [2, None][None, 1] which would get eventually symbolized into
    [2, s0][s1, 1] (notice that the relationship between s0 and s1 is broken).
    If we instead rewrite the expressions as InferStride so we have [2, 3][InferStride(1), 1]
    and [2, 4][InferStride(1), 1] we now join to [2, None][InferStride(1), 1] will
    result in [2, s0][s0, 1], as desired.
    """

    # Index of the dimension whose stride[dim] * size[dim] product this
    # placeholder stands for.
    dim: int
|
|
|
|
|
|
|
|
# Generic element type for the semilattice merge helpers
# (FrameStateSizeEntry._merge_atom / _merge_atom_tup).
_T = TypeVar("_T")
|
|
|
|
|
|
|
|
class AutoUnset(enum.Enum):
    """
    The identity element of our semilattice, a generic "don't know" element that
    is always subsumed when we get more information.
    """

    token = 0


# Module-level singleton; merge logic compares against this with `is`.
auto_unset = AutoUnset.token
|
|
|
|
|
|
|
|
class AutoDynamic(enum.Enum):
    """
    The top element of our (bounded) semilattice, whenever you merge this with
    any other element you always get it again
    """

    token = 0


# Module-level singleton; `x is auto_dynamic` marks a value as dynamic.
auto_dynamic = AutoDynamic.token
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
class FrameStateSizeEntry: |
|
|
scalar: Union[int, AutoDynamic, AutoUnset] = dataclasses.field(default=auto_unset) |
|
|
|
|
|
|
|
|
size: Union[AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]] = ( |
|
|
dataclasses.field(default=auto_unset) |
|
|
) |
|
|
stride: Union[ |
|
|
AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...] |
|
|
] = dataclasses.field(default=auto_unset) |
|
|
|
|
|
def render(self) -> str: |
|
|
|
|
|
def render_single(s: Union[int, AutoDynamic, AutoUnset, InferStride]) -> str: |
|
|
if s is auto_dynamic: |
|
|
return "?" |
|
|
elif s is auto_unset: |
|
|
|
|
|
return "auto unset" |
|
|
elif isinstance(s, InferStride): |
|
|
return f"S({s.dim})" |
|
|
else: |
|
|
return str(s) |
|
|
|
|
|
def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str: |
|
|
return "[" + ", ".join(render_single(s) for s in ss) + "]" |
|
|
|
|
|
|
|
|
if self.size is auto_dynamic and self.stride is auto_dynamic: |
|
|
if self.scalar is auto_dynamic: |
|
|
return "fully dynamic scalar or tensor" |
|
|
else: |
|
|
return f"scalar {self.scalar}" |
|
|
elif self.scalar is auto_dynamic: |
|
|
if isinstance(self.size, tuple) and isinstance(self.stride, tuple): |
|
|
return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}" |
|
|
|
|
|
|
|
|
return "unusual {repr(self)}" |
|
|
|
|
|
def __post_init__(self) -> None: |
|
|
assert not isinstance(self.scalar, torch.SymInt), self.scalar |
|
|
if isinstance(self.size, tuple): |
|
|
for s in self.size: |
|
|
assert not isinstance(s, torch.SymInt), s |
|
|
if isinstance(self.stride, tuple): |
|
|
for s1 in self.stride: |
|
|
assert not isinstance(s1, torch.SymInt), s1 |
|
|
|
|
|
def is_size_dynamic(self, dim: int) -> bool: |
|
|
if self.size is auto_dynamic: |
|
|
return True |
|
|
if self.size is auto_unset: |
|
|
return False |
|
|
return self.size[dim] is auto_dynamic |
|
|
|
|
|
def is_stride_dynamic(self, dim: int) -> bool: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not ( |
|
|
isinstance(self.size, tuple) and all(type(s) is int for s in self.size) |
|
|
): |
|
|
return False |
|
|
if self.stride is auto_dynamic: |
|
|
return True |
|
|
if self.stride is auto_unset: |
|
|
return False |
|
|
return self.stride[dim] is auto_dynamic |
|
|
|
|
|
@staticmethod |
|
|
def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]: |
|
|
return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs) |
|
|
|
|
|
@classmethod |
|
|
def make_scalar(cls, x: int) -> FrameStateSizeEntry: |
|
|
return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic) |
|
|
|
|
|
@classmethod |
|
|
def make_tensor( |
|
|
cls, size: tuple[int, ...], stride: tuple[int, ...] |
|
|
) -> FrameStateSizeEntry: |
|
|
return FrameStateSizeEntry( |
|
|
scalar=auto_dynamic, |
|
|
size=cls._munge_symint(size), |
|
|
stride=cls._munge_symint(stride), |
|
|
) |
|
|
|
|
|
@classmethod |
|
|
def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry: |
|
|
return FrameStateSizeEntry( |
|
|
scalar=auto_unset, |
|
|
size=cls._munge_symint(size), |
|
|
stride=auto_unset, |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _merge_atom(x: _T, y: _T) -> Union[AutoDynamic, _T]: |
|
|
if x is auto_unset: |
|
|
return y |
|
|
if y is auto_unset: |
|
|
return x |
|
|
if x is auto_dynamic or y is auto_dynamic or x != y: |
|
|
return auto_dynamic |
|
|
return x |
|
|
|
|
|
@classmethod |
|
|
def _merge_atom_tup( |
|
|
cls, |
|
|
xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]], |
|
|
ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]], |
|
|
) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]: |
|
|
if xs is auto_unset: |
|
|
return ys |
|
|
if ys is auto_unset: |
|
|
return xs |
|
|
if xs is auto_dynamic or ys is auto_dynamic: |
|
|
return auto_dynamic |
|
|
if len(xs) != len(ys): |
|
|
return auto_dynamic |
|
|
return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys)) |
|
|
|
|
|
def __ior__(self, other: Self) -> Self: |
|
|
self.scalar = self._merge_atom(self.scalar, other.scalar) |
|
|
self.size = self._merge_atom_tup(self.size, other.size) |
|
|
self.stride = self._merge_atom_tup(self.stride, other.stride) |
|
|
return self |
|
|
|
|
|
|
|
|
def update_automatic_dynamic(
    tx: InstructionTranslator,
    name: str,
    entry: FrameStateSizeEntry,
    *,
    is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
    """Merge the observation *entry* for input *name* into the persistent
    profile of the code object being traced by *tx*, and return the entry
    now stored there.

    With automatic_dynamic_shapes enabled, the stored entry is joined
    (``|=``) with the observation, and any scalar/size/stride change on a
    previously-seen name is logged (debug log + "automatic_dynamic" compile
    event).  With it disabled, the observation simply overwrites the stored
    entry.
    """
    code_id = CodeId.make(tx.f_code)
    frame_state = get_code_state()[code_id]
    if torch._dynamo.config.automatic_dynamic_shapes:
        # defaultdict: looking up a new name creates a fresh auto_unset entry.
        is_update = name in frame_state.automatic_dynamic
        mut_entry = frame_state.automatic_dynamic[name]
        old_entry = copy.copy(mut_entry)
        # In-place semilattice join; disagreements collapse to auto_dynamic.
        mut_entry |= entry

        # Everything below is telemetry: report what changed (and why a
        # recompile may have happened) when an existing entry was updated.
        if is_update and old_entry.scalar != mut_entry.scalar:
            log.debug(
                "automatic dynamic int %s val %s != %s",
                name,
                entry.scalar,
                old_entry.scalar,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "scalar",
                    "reason": "scalar change",
                    "cached": str(old_entry.scalar),
                    "new": str(entry.scalar),
                },
            )
            if is_unspecialized_nn_module:
                log.info(
                    "%s is converted to a symbolic integer. It is an attribute of a "
                    "user defined nn module class. If you wish to keep it static, you can "
                    "mark the nn module class as `torch._dynamo.mark_static`.",
                    name,
                )

        def log_tup(
            tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None
        ) -> None:
            # Log one changed tuple field ("size" or "stride"), either as a
            # whole (i is None) or for a single dimension i.
            entry_tup = (
                getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i]
            )
            old_entry_tup = (
                getattr(old_entry, tup_name)
                if i is None
                else getattr(old_entry, tup_name)[i]
            )
            log.debug(
                "automatic dynamic %s %s %s %s != %s",
                tup_name,
                name,
                short_reason,
                entry_tup,
                old_entry_tup,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "all" if i is None else i,
                    "reason": long_reason,
                    "cached": str(old_entry_tup),
                    "new": str(entry_tup),
                },
            )

        if is_update and old_entry.size != mut_entry.size:
            if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple):
                if len(old_entry.size) != len(entry.size):
                    log_tup("size", "dim", "dimensionality change")
                else:
                    # Same rank: report each dim that differs individually.
                    for i in range(len(entry.size)):
                        if old_entry.size[i] != entry.size[i]:
                            log_tup("size", f"size({i})", "size change", i)
            else:
                # One side is not a concrete tuple (auto_unset/auto_dynamic).
                log_tup("size", "other", "other")

        if is_update and old_entry.stride != mut_entry.stride:
            if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple):
                if len(old_entry.stride) != len(entry.stride):
                    log_tup("stride", "dim", "dimensionality change")
                else:
                    for i in range(len(entry.stride)):
                        if old_entry.stride[i] != entry.stride[i]:
                            log_tup("stride", f"stride({i})", "stride change", i)
            else:
                log_tup("stride", "other", "other")
    else:
        # automatic_dynamic_shapes disabled: no merging, just overwrite.
        old_entry = frame_state.automatic_dynamic[name]
        log.debug(
            "automatic dynamic is off, overwriting int %s val %s -> %s",
            name,
            old_entry.scalar,
            entry.scalar,
        )
        frame_state.automatic_dynamic[name] = entry
        mut_entry = entry

    return mut_entry
|
|
|
|
|
|
|
|
def process_automatic_dynamic(
    tx: InstructionTranslator,
    name: str,
    entry: FrameStateSizeEntry,
    *,
    is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
    """Route an observed entry into the profile, accounting for distributed
    compilation state on *tx*.

    Without distributed state the entry is merged directly.  When distributed
    state exists but ``all_states`` has not been gathered yet, the entry is
    only stashed in ``local_state``.  Once ``all_states`` is available, every
    state's observation for *name* is replayed into the profile and the final
    merged entry is returned.
    """
    st = tx.distributed_state
    if st is None:
        # Plain (non-distributed) path: merge straight into the profile.
        return update_automatic_dynamic(
            tx, name, entry, is_unspecialized_nn_module=is_unspecialized_nn_module
        )
    if st.all_states is None:
        # States not gathered yet: record locally, leave the profile untouched.
        st.local_state.automatic_dynamic[name] = entry
        return entry
    # Replay each gathered state's observation for this name; the last merge
    # result wins (each call folds into the same profile entry).
    merged: Optional[FrameStateSizeEntry] = None
    for sub_state in st.all_states:
        if name in sub_state.automatic_dynamic:
            merged = update_automatic_dynamic(
                tx,
                name,
                sub_state.automatic_dynamic[name],
                is_unspecialized_nn_module=is_unspecialized_nn_module,
            )
    assert merged is not None
    return merged
|
|
|
|
|
|
|
|
def format_cache_key(key: str) -> str: |
|
|
|
|
|
|
|
|
rank = None |
|
|
if dist.is_available() and dist.is_initialized(): |
|
|
rank = dist.get_rank() |
|
|
|
|
|
tag = torch.compiler.config.cache_key_tag |
|
|
return f"{key}:{rank}:{tag}" |
|
|
|
|
|
|
|
|
def get_cache_key() -> Optional[str]: |
|
|
|
|
|
if torch.compiler.config.force_disable_caches: |
|
|
warn_once( |
|
|
"dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" |
|
|
) |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
if (r := torch.compiler.config.job_id) is not None: |
|
|
if r.startswith("mast:"): |
|
|
raise ReservedWorkflowIdUserError( |
|
|
"torch.compiler.config.job_id with prefix 'mast:' is reserved for " |
|
|
"automatically generated job id associated with a specific MAST job " |
|
|
"name and version." |
|
|
) |
|
|
return format_cache_key(r) |
|
|
|
|
|
if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: |
|
|
mast_job_name, mast_job_version = name_version |
|
|
return format_cache_key(f"mast:{mast_job_name}:{mast_job_version}") |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def get_extra_cache_key(sticky_key: str) -> Optional[str]: |
|
|
if torch.compiler.config.force_disable_caches: |
|
|
warn_once( |
|
|
"dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" |
|
|
) |
|
|
return None |
|
|
|
|
|
return format_cache_key(sticky_key) |
|
|
|
|
|
|
|
|
|
|
|
def code_state_path(cache_key: str) -> Optional[str]: |
|
|
if not torch._dynamo.config.automatic_dynamic_local_pgo: |
|
|
log.debug("automatic_dynamic_local_pgo not enabled") |
|
|
return None |
|
|
|
|
|
from torch._inductor.runtime.runtime_utils import cache_dir |
|
|
|
|
|
code_state_key = re.sub(r'[<>:"/\\|?*]', "_", f"code_state_{cache_key}.pkl") |
|
|
return os.path.join(cache_dir(), "dynamo", code_state_key) |
|
|
|
|
|
|
|
|
def should_use_remote_dynamo_pgo_cache() -> bool: |
|
|
if torch.compiler.config.force_disable_caches: |
|
|
return False |
|
|
|
|
|
if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: |
|
|
return r |
|
|
|
|
|
if not is_fbcode(): |
|
|
return False |
|
|
|
|
|
if torch._utils_internal.is_fb_unit_test(): |
|
|
return False |
|
|
|
|
|
try: |
|
|
from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION |
|
|
except ModuleNotFoundError: |
|
|
return False |
|
|
|
|
|
return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( |
|
|
"pytorch/remote_cache:dynamo_pgo_version" |
|
|
) |
|
|
|
|
|
|
|
|
def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: |
|
|
from torch._inductor.remote_cache import create_cache |
|
|
|
|
|
if not should_use_remote_dynamo_pgo_cache(): |
|
|
return None |
|
|
|
|
|
return create_cache( |
|
|
"dynamo-pgo", |
|
|
is_fbcode(), |
|
|
"FbRemoteDynamoPGOCache", |
|
|
"RemoteDynamoPGOCache", |
|
|
) |
|
|
|
|
|
|
|
|
def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: |
|
|
dynamic_sources: OrderedSet[str] = OrderedSet() |
|
|
for src, fs in code_state.automatic_dynamic.items(): |
|
|
dynamic = False |
|
|
if isinstance(fs.size, tuple): |
|
|
dynamic = auto_dynamic in fs.size |
|
|
elif fs.scalar == auto_dynamic: |
|
|
dynamic = True |
|
|
if dynamic: |
|
|
dynamic_sources.add(src) |
|
|
return dynamic_sources |
|
|
|
|
|
|
|
|
def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
    """Emit telemetry listing this frame's dynamic sources.

    If any source in *f_code*'s profile was observed to be dynamic, the
    comma-joined whitelist is attached to a "pgo.dynamic_whitelist" compile
    event, and (once per process) an MLHub insight is recorded.
    """
    global _LOGGED_DYNAMIC_ALLOWLIST
    code_id = CodeId.make(f_code)
    frame_state = get_code_state()[code_id]
    frame_whitelist = ",".join(_collect_dynamic_sources(frame_state))
    if frame_whitelist:
        with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
            CompileEventLogger.pt2_compile(
                name, recompile_dynamic_whitelist=frame_whitelist
            )
        if not _LOGGED_DYNAMIC_ALLOWLIST:
            torch._utils_internal.add_mlhub_insight(
                category="dynamic_shapes_analysis",
                insight="Dynamic shape recompilation detected",
                insight_description="PGO detected a recompilation due to dynamic shapes. \
                Please follow the instruction from the action link to reduce \
                recompilation overhead.",
            )
            # Only log the insight once per process to avoid spam.
            _LOGGED_DYNAMIC_ALLOWLIST = True
|
|
|
|
|
|
|
|
def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: |
|
|
code_state_str = "\n".join( |
|
|
f"{k}:\n" |
|
|
+ "\n".join( |
|
|
f" {src}: {fs.render()}" for src, fs in v.automatic_dynamic.items() |
|
|
) |
|
|
for k, v in cs.items() |
|
|
) |
|
|
dynamic_sources: OrderedSet[str] = OrderedSet() |
|
|
for state in cs.values(): |
|
|
dynamic_sources.update(_collect_dynamic_sources(state)) |
|
|
if dynamic_sources: |
|
|
code_state_str += ( |
|
|
"\n\nPGO detected a recompilation due to dynamic shapes. " |
|
|
"To reduce shape recompilations by compiling dynamically to start, " |
|
|
f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"' |
|
|
) |
|
|
return code_state_str |
|
|
|
|
|
|
|
|
def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
    """Merge *src* into *dst* in place, but only when both describe values of
    the same rank (equal-length size tuples, or two non-tensor entries)."""

    def _rank(e: FrameStateSizeEntry) -> int:
        # Entries without a concrete size tuple all share rank -1.
        return len(e.size) if isinstance(e.size, tuple) else -1

    if _rank(src) == _rank(dst):
        # Semilattice join; conflicting observations collapse to dynamic.
        dst |= src
|
|
|
|
|
|
|
|
@CacheArtifactFactory.register
class PGOCacheArtifact(CacheArtifact):
    """Cache artifact wrapping a pickled PGO profile, so profiles can be
    bundled into and replayed from the portable ("mega") cache."""

    @override
    def populate_cache(self) -> None:
        # Replay this artifact into the local PGO cache file, remapping
        # MAST-derived keys to the current job's identity first.
        meta = write_local_impl(
            self._rewrite_cache_key_for_mega_cache(self.key), self.content
        )
        assert meta is not None

    @override
    @staticmethod
    def type() -> str:
        return "pgo"

    @staticmethod
    def _rewrite_cache_key_for_mega_cache(original_key: str) -> str:
        """
        The PGO cache artifact key for a MAST job contains the job name and the version.
        When we want to use the cache artifact on a different MAST job, we need to
        update the key to use the new MAST job's name and version.
        """
        if not original_key.startswith("mast:"):
            # Not MAST-derived; the key is portable as-is.
            return original_key
        if (new_key := get_cache_key()) is not None:
            return new_key
        # Could not compute a key for the current job; keep the original.
        return original_key
|
|
|
|
|
|
|
|
def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]:
    """Record a profile cache hit from source *ty* ("local"/"remote") keyed
    by *key*: log it, emit a trace artifact, snapshot the loaded state, and
    return it.

    The deep-copied snapshot in _INIT_CODE_STATE lets put_code_state detect
    later that nothing changed and skip the write.
    """
    global _INIT_CODE_STATE
    assert isinstance(_CODE_STATE, defaultdict)
    log.info("get_code_state %s hit %s, %d entries", key, ty, len(_CODE_STATE))
    trace_structured_artifact(
        f"get_{ty}_code_state",
        "string",
        lambda: render_code_state(_CODE_STATE),
    )
    set_feature_use("pgo", True)
    # Snapshot for no-change detection in put_code_state.
    _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
    return _CODE_STATE
|
|
|
|
|
|
|
|
def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
    """Load the pickled PGO profile from the local cache file for *cache_key*.

    On success, installs the profile as _CODE_STATE, records it as a cache
    artifact, and returns it (via hit()).  Returns None when local PGO is
    disabled, the file is absent, or reading/unpickling fails (logged, not
    raised).
    """
    global _CODE_STATE
    path = code_state_path(cache_key)
    if path is not None and os.path.exists(path):
        with dynamo_timed(
            name := "pgo.get_local_code_state", log_pt2_compile_event=True
        ):
            CompileEventLogger.pt2_compile(name, cache_key=cache_key)

            with open(path, "rb") as f:
                try:
                    content = f.read()
                    # NOTE: pickle is only safe because this is our own
                    # process-local cache dir, not untrusted input.
                    _CODE_STATE = pickle.loads(content)
                    # f.tell() after read() == total bytes read.
                    CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell())
                except Exception:
                    # Corrupt/incompatible profile: warn and fall through to
                    # the default empty state rather than failing compilation.
                    log.warning(
                        "get_code_state failed while reading %s", path, exc_info=True
                    )
                else:
                    CacheArtifactManager.record_artifact(
                        PGOCacheArtifact.type(), cache_key, content
                    )
                    return hit(path, "local")
    return None
|
|
|
|
|
|
|
|
def lookup_remote_cache_entry(
    remote_cache: RemoteCache[JsonDataTy],
    cache_key: str,
    event_name: Optional[str] = None,
) -> Optional[defaultdict[CodeId, CodeState]]:
    """Fetch and deserialize one PGO profile from *remote_cache*.

    The stored payload is a JSON dict with a base64-encoded pickle under
    "data".  Returns the decoded code state, or None on a miss or any
    read/parse failure (logged, never raised).  When *event_name* is given,
    the payload size is attached to that compile event.
    """
    code_state = None
    try:
        cache_data = remote_cache.get(cache_key)
    except Exception:
        log.warning("get_code_state failed remote read on %s", cache_key, exc_info=True)
    else:
        if cache_data is not None:
            try:
                assert isinstance(cache_data, dict)
                data = cache_data["data"]
                assert isinstance(data, str)
                payload = base64.b64decode(data)
                if event_name is not None:
                    CompileEventLogger.pt2_compile(
                        event_name, cache_size_bytes=len(payload)
                    )
                # NOTE: pickle.loads on cache contents — acceptable only
                # because the remote PGO cache is a trusted internal store.
                code_state = pickle.loads(payload)
            except Exception:
                log.warning(
                    "get_code_state failed parsing remote result on %s",
                    cache_key,
                    exc_info=True,
                )
            else:
                # Only record the artifact once decoding fully succeeded.
                CacheArtifactManager.record_artifact(
                    PGOCacheArtifact.type(), cache_key, payload
                )
        else:
            log.info("get_code_state remote miss on %s", cache_key)
    return code_state
|
|
|
|
|
|
|
|
def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
    """Load the PGO profile for *cache_key* from the remote cache, if remote
    PGO is enabled.

    On success installs the profile as _CODE_STATE and returns it (via
    hit()); returns None when remote caching is disabled or the lookup
    missed/failed.
    """
    global _CODE_STATE
    remote_cache = get_remote_cache()
    if remote_cache is not None:
        with dynamo_timed(
            name := "pgo.get_remote_code_state",
            log_pt2_compile_event=True,
            dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
        ):
            CompileEventLogger.pt2_compile(name, cache_key=cache_key)
            code_state = lookup_remote_cache_entry(remote_cache, cache_key, name)
            if code_state is not None:
                _CODE_STATE = code_state
                return hit(cache_key, "remote")
    return None
|
|
|
|
|
|
|
|
def add_extra_remote_code_state(cache_key: str) -> None:
    """
    Reads an additional PGO profile from the given cache key, and merges it with the default PGO profile.
    """
    global _CODE_STATE
    # get_code_state() must have initialized the default profile already.
    assert _CODE_STATE is not None

    remote_cache = get_remote_cache()
    if remote_cache is not None:
        with dynamo_timed(
            name := "pgo.add_extra_remote_code_state",
            log_pt2_compile_event=True,
            dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
        ):
            CompileEventLogger.pt2_compile(name, cache_key=cache_key)
            code_state = lookup_remote_cache_entry(remote_cache, cache_key)
            log.info(
                "add_extra_code_state %s hit, %d entries",
                cache_key,
                len(code_state) if code_state is not None else 0,
            )
            if code_state is not None:
                for code_id, state in code_state.items():
                    if code_id in _CODE_STATE:
                        # Code object already profiled: merge entry by entry
                        # (merge_pgo_entry skips rank-mismatched entries).
                        for src, entry in state.automatic_dynamic.items():
                            merge_pgo_entry(
                                entry, _CODE_STATE[code_id].automatic_dynamic[src]
                            )
                    else:
                        # New code object: adopt the extra profile wholesale.
                        _CODE_STATE[code_id] = state

                trace_structured_artifact(
                    "add_extra_remote_code_state",
                    "string",
                    lambda: render_code_state(code_state),
                )
|
|
|
|
|
|
|
|
def get_code_state() -> defaultdict[CodeId, CodeState]:
    """Return the process-wide PGO profile, initializing it on first call.

    Initialization order: start from an empty defaultdict, then (when a
    cache key exists) try the local cache, fall back to the remote cache,
    and finally merge in the optional extra sticky-read profile.
    """
    global _CODE_STATE, _INIT_CODE_STATE
    if _CODE_STATE is not None:
        return _CODE_STATE

    # Always have a usable (possibly empty) profile, even if every cache
    # lookup below is skipped or fails.
    _CODE_STATE = defaultdict(CodeState)

    cache_key = get_cache_key()
    if cache_key is None:
        return _CODE_STATE

    local_code_state = get_local_code_state(cache_key)

    # Remote is only consulted when local missed.
    if local_code_state is None:
        get_remote_code_state(cache_key)

    # Optionally merge an extra profile from a user-configured sticky key.
    if (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None:
        extra_read_key = get_extra_cache_key(sticky_read)
        if extra_read_key is not None:
            add_extra_remote_code_state(extra_read_key)

    log.info("get_code_state using default")

    assert _CODE_STATE is not None
    return _CODE_STATE
|
|
|
|
|
|
|
|
def put_code_state() -> None:
    """Persist the current profile to the local and remote caches, plus an
    optional user-configured extra (sticky) remote key.

    No-ops when the profile was never initialized, is unchanged since it was
    loaded, or no cache key can be computed.
    """
    if _CODE_STATE is None:
        log.info("put_code_state: never initialized, will not write")
        return

    if _CODE_STATE == _INIT_CODE_STATE:
        log.info("put_code_state: no change, skipping")
        return

    cache_key = get_cache_key()
    if cache_key is None:
        log.info("put_code_state: no cache key, skipping")
        return

    put_local_code_state(cache_key)
    put_remote_code_state(cache_key)

    # Optionally mirror the profile under an extra sticky-write key.
    sticky_write = torch.compiler.config.pgo_extra_write_key
    if sticky_write is not None:
        extra_write_key = get_extra_cache_key(sticky_write)
        if extra_write_key is not None:
            put_remote_code_state(extra_write_key)
|
|
|
|
|
|
|
|
def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
    """Atomically write *pickled_code* to the local profile file for
    *cache_key*.

    Returns (path, bytes_written), or None when local PGO is disabled.
    """
    target = code_state_path(cache_key)
    if target is None:
        return None

    # Write to a sibling temp file and os.replace() it into place so readers
    # never observe a partially written profile; the file lock serializes
    # concurrent writers (same tmp path) within the timeout.
    temp_target = target + ".tmp"
    lock_target = target + ".lock"

    from torch.utils._filelock import FileLock

    os.makedirs(os.path.dirname(target), exist_ok=True)

    with FileLock(lock_target, timeout=LOCK_TIMEOUT):
        with open(temp_target, "wb") as out:
            out.write(pickled_code)
            size = out.tell()
        os.replace(temp_target, target)
    return target, size
|
|
|
|
|
|
|
|
def put_local_code_state(cache_key: str) -> None:
    """Pickle the current profile, record it as a cache artifact, and write
    it to the local cache file for *cache_key* (no-op if local PGO is
    disabled)."""
    with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
        CompileEventLogger.pt2_compile(name, cache_key=cache_key)
        assert _CODE_STATE is not None

        pickled_code = pickle.dumps(_CODE_STATE)

        # Record the artifact regardless of whether the local write below
        # is enabled, so the mega-cache can still carry the profile.
        CacheArtifactManager.record_artifact(
            PGOCacheArtifact.type(), cache_key, pickled_code
        )

        meta = write_local_impl(cache_key, pickled_code)
        if meta is None:
            log.info("put_code_state: local cache disabled")
            return
        path, size = meta

        CompileEventLogger.pt2_compile(name, cache_size_bytes=size)
        log.info("put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE))
        trace_structured_artifact(
            "put_local_code_state",
            "string",
            lambda: render_code_state(_CODE_STATE),
        )
|
|
|
|
|
|
|
|
def put_remote_code_state(cache_key: str) -> None:
    """Serialize the current profile and upload it to the remote cache under
    *cache_key* (no-op if remote PGO is disabled).

    The payload is stored as a JSON dict with the pickle base64-encoded
    under "data" — the inverse of lookup_remote_cache_entry's decoding.
    """
    with dynamo_timed(
        name := "pgo.put_remote_code_state",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pgo_put_remote_code_state_time_us",
    ):
        CompileEventLogger.pt2_compile(name, cache_key=cache_key)
        assert _CODE_STATE is not None

        remote_cache = get_remote_cache()

        if remote_cache is None:
            log.info("put_code_state: remote cache disabled")
            return

        content = pickle.dumps(_CODE_STATE)
        CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content))
        cache_data: JsonDataTy = {
            "data": base64.b64encode(content).decode("ascii"),
        }
        remote_cache.put(cache_key, cache_data)
        log.info(
            "put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE)
        )

        trace_structured_artifact(
            "put_remote_code_state",
            "string",
            lambda: render_code_state(_CODE_STATE),
        )
|
|
|
|
|
|
|
|
|
|
|
def reset_code_state() -> None:
    """Drop all in-memory PGO state: the live profile, the loaded-state
    snapshot, and the log-once insight flag."""
    global _CODE_STATE, _INIT_CODE_STATE, _LOGGED_DYNAMIC_ALLOWLIST
    _LOGGED_DYNAMIC_ALLOWLIST = False
    _INIT_CODE_STATE = None
    _CODE_STATE = None
|
|
|