from __future__ import annotations

import base64
import copyreg
import dataclasses
import functools
import hashlib
import importlib
import importlib.resources
import io
import itertools
import json
import logging
import os
import pickle
import pkgutil
import re
import shlex
import shutil
import struct
import subprocess
import sys
import tempfile
import textwrap
import threading
import warnings
from bisect import bisect_right
from copy import copy
from ctypes import c_void_p, CDLL, cdll
from datetime import timedelta
from functools import lru_cache, partial
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from time import time, time_ns
from types import ModuleType
from typing import (
    Any,
    Callable,
    cast,
    Generic,
    NoReturn,
    Optional,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from typing_extensions import override, Self

import torch
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.exc import SkipFrame
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.common import (
    custom_backend_codegen_configs,
    custom_backend_passes,
    init_backend_registration,
)
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.codegen.rocm.compile_command import (
    rocm_compile_command,
    rocm_compiler,
)
from torch._inductor.compile_worker.utils import in_toplevel_process
from torch._inductor.cpp_builder import (
    _LINKER_SCRIPT,
    _set_gpu_runtime_env,
    _TORCH_PATH,
    _transform_cuda_paths,
    convert_cubin_to_obj,
    CppBuilder,
    CppOptions,
    CppTorchDeviceOptions,
    get_compiler_version_info,
    get_ld_and_objcopy,
    get_name_and_dir_from_output_file_path,
    normalize_path_separator,
    run_asm_build_object,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.custom_graph_pass import (
    CustomGraphModulePass,
    CustomGraphPass,
    CustomGraphPassType,
    CustomPartitionerFn,
    CustomPartitionerFnType,
)
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.runtime.compile_tasks import _reload_python_module
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
    ALIGN_BYTES,
    clear_on_fresh_cache,
    is_linux,
    is_windows,
)
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,
    FakeTensor,
    TensorMetadata,
)
from torch._utils_internal import log_cache_bypass
from torch.compiler import config as cconfig
from torch.compiler._cache import (
    CacheArtifact,
    CacheArtifactFactory,
    CacheArtifactManager,
)
from torch.export.pt2_archive._package_weights import TensorProperties, Weights
from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
from torch.utils._ordered_set import OrderedSet

from .output_code import CompiledFxGraph
from .remote_cache import create_cache
from .runtime import autotune_cache
from .runtime.autotune_cache import AutotuneCacheBundler
from .triton_bundler import TritonBundler
from .virtualized import V


if config.is_fbcode():
    from triton.fb.build import build_paths


T = TypeVar("T")

if TYPE_CHECKING:
    from collections.abc import Generator, KeysView, Sequence
    from concurrent.futures import Future

    from .compile_fx import _CompileFxKwargs
    from .cpp_builder import BuildOptionsBase
    from .graph import GraphLowering
    from .ir import ChoiceCaller
    from .output_code import CompiledFxGraphConstants, OutputCode
    from .remote_cache import JsonDataTy, RemoteCache
    from .runtime.hints import HalideInputSpec, HalideMeta
    from .runtime.triton_heuristics import CachingAutotuner
    from .utils import InputType


_IS_WINDOWS = sys.platform == "win32"
LOCK_TIMEOUT = 600

output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning")
log = logging.getLogger(__name__)


def use_re_build() -> bool:
    """
    Whether to build via remote execution; currently used only for CUTLASS
    compilation.
    """
    if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()):
        from triton.fb.re_build_helper import should_build_locally

        return not should_build_locally()
    return False


def get_cpp_wrapper_cubin_path_name() -> str:
    return "cubin_path" if torch.version.hip is None else "hsaco_path"


def get_kernel_bin_format(device: str) -> str:
    if device == "cuda":
        return "cubin" if torch.version.hip is None else "hsaco"
    elif device == "xpu":
        return "spv"
    else:
        return ""
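

# Usage sketch (illustrative; the result depends on the local build):
#
#     >>> get_kernel_bin_format("cuda")  # "hsaco" on a ROCm build
#     'cubin'
#     >>> get_kernel_bin_format("cpu")   # unknown devices fall through
#     ''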


class CacheBase:
    @staticmethod
    @functools.cache
    def get_system() -> dict[str, Any]:
        from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key

        if HAS_TRITON:
            triton_version = triton_key()
        else:
            triton_version = None

        try:
            system: dict[str, Any] = {
                "device": {"name": None},
                "version": {
                    "triton": triton_version,
                },
            }
            device_properties = torch.cuda.get_device_properties(
                torch.cuda.current_device()
            )
            if torch.version.cuda is not None:
                system["device"]["name"] = device_properties.name
                system["version"]["cuda"] = torch.version.cuda
            else:
                system["device"]["name"] = device_properties.gcnArchName
                system["version"]["hip"] = torch.version.hip
        except (AssertionError, RuntimeError):
            # No GPU is available; the key carries only the Triton version.
            system = {}

        system["hash"] = hashlib.sha256(
            json.dumps(system, sort_keys=True).encode("utf-8")
        ).hexdigest()

        return system

    @staticmethod
    @clear_on_fresh_cache
    @functools.cache
    def get_local_cache_path() -> Path:
        return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))

    def __init__(self) -> None:
        self.system = CacheBase.get_system()

    def get_local_cache(self) -> dict[str, Any]:
        local_cache_path = self.get_local_cache_path()
        if not local_cache_path.is_file():
            return {}
        with open(local_cache_path) as local_cache_fp:
            local_cache = json.load(local_cache_fp)
        return local_cache["cache"]

    def update_local_cache(self, local_cache: dict[str, Any]) -> None:
        local_cache_path = self.get_local_cache_path()
        write_atomic(
            str(local_cache_path),
            json.dumps({"system": self.system, "cache": local_cache}, indent=4),
            make_dirs=True,
        )


class LocalCache(CacheBase):
    def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys:
            if key in sub_cache:
                sub_cache = sub_cache[key]
            else:
                return None

        return sub_cache

    def set_value(self, *keys: str, value: Any) -> None:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys[0:-1]:
            sub_cache.setdefault(key, {})
            sub_cache = sub_cache[key]
        sub_cache[keys[-1]] = value

        self.update_local_cache(cache)
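

# Usage sketch (hypothetical keys; LocalCache persists to a JSON file under
# the Inductor cache dir, keyed by arbitrary strings):
#
#     cache = LocalCache()
#     cache.set_value("mm", "shape_sig", value=0.012)
#     assert cache.lookup("mm", "shape_sig") == 0.012
#     assert cache.lookup("mm", "missing") is None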


class PersistentCache(CacheBase):
    def lookup(
        self,
        choices: list[ChoiceCaller],
        op: str,
        inputs: str,
        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
        hint_override: Optional[int] = None,
    ) -> dict[ChoiceCaller, float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:

        1. Check local_cache[op][inputs][precision][choice]; return the timing
           if cached.
        2. If benchmark is not None:
            a. `max_autotune_gemm=True`: benchmark the choice, update
               local_cache[op][inputs][precision][choice], and return the timing.
            b. `max_autotune_gemm=False`: don't benchmark the choice; return
               no timings.
        """
        precision = torch.get_float32_matmul_precision()
        cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs

        timings = {}

        def check_cache(cache: dict[str, Any]) -> bool:
            """Check if `cache` contains data for all the choices."""
            hit = True
            for choice in choices:
                choice_hash = choice.hash_key()
                if choice_hash in cache.get(op, {}).get(cache_key, {}).get(
                    precision, {}
                ):
                    # Cache hit.
                    timings[choice] = cache[op][cache_key][precision][choice_hash]
                else:
                    # Cache miss.
                    hit = False
                    break
            return hit

        local_cache = self.get_local_cache() if config.autotune_local_cache else {}
        if (not check_cache(local_cache)) and (benchmark is not None):
            # Re-benchmark everything to get consistent timings from the same machine.
            timings = benchmark(choices)
            assert all(choice in timings for choice in choices)
            local_cache.setdefault(op, {})
            local_cache[op].setdefault(cache_key, {}).setdefault(precision, {})
            for choice, timing in timings.items():
                local_cache[op][cache_key][precision][choice.hash_key()] = timing

            self.update_local_cache(local_cache)

        return timings


def get_lock_dir() -> str:
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def sha256_hash(data: bytes) -> str:
    # 51 base-32 characters keep 255 of the 256 digest bits and drop the "=" padding.
    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
    if extra:
        extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
        hashing_str = hashing_str + b"||" + extra_b
    return "c" + sha256_hash(hashing_str)
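

# Property sketch (follows directly from the definitions above): the hash is
# deterministic in the content plus the extra string, str and utf-8 bytes
# hash identically, and the "c" prefix marks code hashes.
#
#     assert code_hash("x = 1") == code_hash(b"x = 1")
#     assert code_hash("x = 1").startswith("c")
#     assert code_hash("x = 1", extra="-O3") != code_hash("x = 1")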


def get_path(
    basename: str, extension: str, specified_dir: str = ""
) -> tuple[str, str, str]:
    if specified_dir:
        if os.path.isabs(specified_dir):
            subdir = specified_dir
        else:
            subdir = os.path.join(cache_dir(), specified_dir)
    else:
        subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{extension}")
    return basename, subdir, path
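

# Layout sketch: when no directory is specified, entries are sharded by the
# 2nd and 3rd characters of the basename, e.g. a key "cab123..." with
# extension "py" lands at <cache_dir>/ab/cab123....py.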


def get_hash(
    content: Union[str, bytes], extra: str = "", hash_type: str = "code"
) -> str:
    if hash_type in {"amdgcn", "code", "ptx", "spv"}:
        return code_hash(content, extra)
    if hash_type in {"cubin", "hsaco"}:
        return code_hash(repr(content))
    raise AssertionError(f"Unknown hash type {hash_type}")


class WritableTempFile:
    """
    Avoid "Permission denied" errors on Windows, where a NamedTemporaryFile
    cannot be reopened by name while it is still open:
    https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile

    Example:
        with WritableTempFile("w", suffix=".gv") as temp_file:
            tree.to_dotfile(temp_file.name)
    """

    def __init__(
        self, mode: str = "w", *, encoding: Any = None, suffix: Any = None
    ) -> None:
        self.mode = mode
        self.encoding = encoding
        self.suffix = suffix

    def __enter__(self) -> _TemporaryFileWrapper[Any]:
        self.temp_file = tempfile.NamedTemporaryFile(
            self.mode, encoding=self.encoding, suffix=self.suffix, delete=False
        )
        return self.temp_file

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.temp_file.close()
        os.unlink(self.temp_file.name)


def write(
    content: Union[str, bytes],
    extension: str,
    extra: str = "",
    hash_type: str = "code",
    specified_dir: str = "",
    key: Optional[str] = None,
) -> tuple[str, str]:
    if key is None:
        # Use the stripped content to compute the hash so keys don't differ
        # just because the content begins or ends with extra newlines.
        key = get_hash(content.strip(), extra, hash_type)
    basename, _subdir, path = get_path(key, extension, specified_dir)
    if not os.path.exists(path):
        write_atomic(path, content, make_dirs=True)
    return basename, path
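

# Usage sketch (hypothetical content): write() is content-addressed, so
# writing identical source twice yields the same path and only touches
# disk once.
#
#     name1, path1 = write("int x;", "cpp")
#     name2, path2 = write("int x;", "cpp")
#     assert (name1, path1) == (name2, path2)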


def write_text(text: str) -> str:
    """
    Write `text` to a file and return the path, which is derived from the
    hash of the content.
    """
    return write(text, "txt")[1]


def write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    # Write into a temporary file first and then rename it into place, so a
    # concurrent reader never observes a partially written file.
    assert isinstance(content, (str, bytes)), (
        "Only strings and byte arrays can be saved in the cache"
    )
    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    try:
        tmp_path.rename(target=path)
    except FileExistsError:
        if not _IS_WINDOWS:
            raise
        # On Windows, rename fails if the destination already exists; fall
        # back to copying over the destination and removing the temp file.
        shutil.copy2(src=tmp_path, dst=path)
        os.remove(tmp_path)


@dataclasses.dataclass
class TensorMetadataAndValues:
    """
    TensorMetadata plus the elements as a list of raw values.
    Used for hashing inlined constants.
    """

    tensor_metadata: TensorMetadata
    values: list[Any]


def _ident(x: T) -> T:
    return x


def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata:
    """
    Extract the tensor metadata and remove the TensorMetadata fields that
    are not needed for caching.
    """
    meta = extract_tensor_metadata(t)
    if not hasattr(t, "_is_inductor_static"):
        meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)

    return meta


class FxGraphCachePickler(pickle.Pickler):
    """
    Custom pickler to customize the pickling of some objects (Tensors), only for the
    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
    objects that don't pickle and/or vary between runs, so we instead capture the
    data that lets us compute a stable, but safe, hash.
    """

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        has_user_defined_triton_kernels: bool = False,
    ) -> None:
        """
        Create an FX graph pickler. Constant tensors (attached as attributes on
        the GraphModule) are pickled with their values unless they are frozen
        params that won't be inlined, in which case only their metadata is
        included. If has_user_defined_triton_kernels=True, the GraphModule's
        generated code is normalized before hashing (see _reduce_graph_module).
        """
        self._stream = io.BytesIO()
        super().__init__(self._stream)

        self.dispatch_table = copyreg.dispatch_table.copy()
        self.dispatch_table.update(
            {
                FakeTensor: self._reduce_fake_tensor,
                torch.Tensor: self._reduce_tensor,
                torch.nn.parameter.Parameter: self._reduce_tensor,
                torch.SymInt: self._reduce_symint,
                torch.fx.experimental._backward_state.BackwardState: self._reduce_unsupported,
            }
        )
        if has_user_defined_triton_kernels:
            # Strip the dynamo side-table indices out of the generated code
            # before hashing (see _reduce_graph_module).
            self.dispatch_table[gm.__class__] = self._reduce_graph_module

        # Run with pickler.fast so we don't memoize repeated objects; equal
        # inputs should always serialize (and therefore hash) identically.
        self.fast = True

    def _reduce_fake_tensor(
        self, t: Tensor
    ) -> tuple[Callable[[T], T], tuple[TensorMetadata]]:
        """
        Custom reducer to pickle FakeTensors.
        """
        metadata = extract_tensor_metadata_for_cache_key(t)
        return (_ident, (metadata,))

    def _reduce_tensor(
        self, t: Tensor
    ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
        """
        Custom reducer to pickle Tensors. If we see tensors, we know they're constants
        stored as attributes on the GraphModule.
        """
        from .graph import GraphLowering

        if t.is_mkldnn:
            # mkldnn tensors currently don't pickle, so we can't cache a
            # compiled graph that contains them. Bypass caching instead.
            raise BypassFxGraphCache("mkldnn tensors unpickleable")

        metadata = extract_tensor_metadata_for_cache_key(t)

        # For frozen params that won't be inlined, the metadata alone is enough.
        if is_frozen_param(t) and not GraphLowering.can_inline_constant(t):
            return (_ident, (metadata,))

        # Otherwise hash the values as well. Copying a very large constant
        # would be expensive; warn so it can be investigated.
        start = time()
        values = t.tolist()
        elapsed = time() - start
        if elapsed > 1.0:
            warnings.warn(
                f"FX graph cache copying of a large constant took {elapsed:.1f}s. "
                "Please file an issue."
            )

        return (_ident, (TensorMetadataAndValues(metadata, values),))

    def _reduce_symint(self, s: SymInt) -> tuple[Callable[[T], T], tuple[str]]:
        """
        Custom reducer to pickle SymInts.
        """
        # Hash the symbolic expression itself (e.g., "s0*2") rather than any
        # concrete hint value, which can vary between runs.
        return (_ident, (str(s),))

    def _reduce_unsupported(self, s: Any) -> NoReturn:
        """
        Custom reducer to handle any objects that we don't support and therefore
        raise to bypass caching.
        """
        raise BypassFxGraphCache("Reduce unsupported")

    def _reduce_graph_module(
        self, gm: torch.fx.GraphModule
    ) -> tuple[Any, tuple[dict[str, Any], str]]:
        """
        Custom reducer for a GraphModule that strips data irrelevant to hashing
        user-defined triton kernels. The generated code refers to a dynamo-time
        side table via `kernel_idx` / `constant_args_idx` arguments; those
        indices don't belong in the hash because the kernel source code is
        included in the cache key separately, and the raw numbers are prone to
        false negatives due to ordering.
        """
        fn, (data, imports) = gm.__reduce__()
        code = data["_code"]
        code = re.sub(r"kernel_idx = \d+", "", code)
        code = re.sub(r"constant_args_idx = \d+", "", code)
        data["_code"] = code
        return fn, (data, imports)

    def dumps(self, obj: Any) -> bytes:
        """
        Pickle an object and return a byte string.
        """
        try:
            self.dump(obj)
            return self._stream.getvalue()
        except (TypeError, AttributeError) as e:
            # Some config values (e.g., callables) may not pickle.
            log.warning("Failed to pickle cache key", exc_info=True)
            raise BypassFxGraphCache("Failed to pickle cache key") from e
        finally:
            # Reset the stream for the next dump.
            self._stream.seek(0)
            self._stream.truncate(0)

    def get_hash(self, obj: Any) -> str:
        """
        Serialize an object and return a hash of the bytes.
        """
        serialized_data = self.dumps(obj)
        return sha256_hash(serialized_data)

    def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
        """
        Get a printable string describing in more detail all the attributes
        comprising an object. Useful for debugging when one graph hashes
        to a different value than another.
        """

        def get_str(obj: Any) -> str:
            if isinstance(obj, torch.Tensor):
                return str(extract_tensor_metadata_for_cache_key(obj))
            elif isinstance(obj, bytes):
                return "<bytes>"
            elif type(obj) in self.dispatch_table:
                # Run the reducer and stringify its payload.
                return str(self.dispatch_table[type(obj)](obj)[1])
            else:
                return str(obj)

        lines = []
        for attr, obj in vars(inp).items():
            if isinstance(obj, list):
                for ii in range(len(obj)):
                    h = self.get_hash(obj[ii])
                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
            elif isinstance(obj, dict):
                for k, v in obj.items():
                    h = self.get_hash(v)
                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
            else:
                h = self.get_hash(obj)
                lines.append(f"[{h}] {attr}: {get_str(obj)}")
        return lines
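

# Debugging sketch (hypothetical `gm_a`/`gm_b` GraphModules and the
# FxGraphHashDetails built from them): when two runs unexpectedly produce
# different cache keys, diff the per-attribute hash lines to find the
# attribute that diverged.
#
#     lines_a = FxGraphCachePickler(gm_a).debug_lines(details_a)
#     lines_b = FxGraphCachePickler(gm_b).debug_lines(details_b)
#     for a, b in zip(lines_a, lines_b):
#         if a != b:
#             print(f"- {a}\n+ {b}")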


def build_code_hash(
    roots: list[str] | None, prefix: str, hasher: hashlib._Hash
) -> None:
    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
        spec = lib.module_finder.find_spec(lib.name, None)
        assert spec is not None
        module = spec.origin
        assert module is not None
        with open(module, "rb") as f:
            hasher.update(spec.name.encode("utf-8"))
            hasher.update(f.read())
        if lib.ispkg:
            # Recurse into the package's submodules.
            build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher)


def torch_key_cache(func: Callable[[], bytes]) -> Callable[[], bytes]:
    """
    A reimplementation of functools.lru_cache that additionally exposes a
    setter for prepopulating the cache.
    """
    _cache: list[bytes] = []

    def wrapper() -> bytes:
        if len(_cache) == 0:
            _cache.append(func())
        return _cache[0]

    def set_val(val: bytes) -> None:
        assert len(_cache) == 0
        _cache.append(val)

    def clear() -> None:
        _cache.clear()

    wrapper.set = set_val  # type: ignore[attr-defined]
    wrapper.clear = clear  # type: ignore[attr-defined]
    return wrapper
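

# Usage sketch: a compile worker can prepopulate the value of a
# @torch_key_cache-wrapped function (torch_key below) with bytes received
# from its parent process instead of recomputing the source hash:
#
#     torch_key.set(key_bytes_from_parent)  # hypothetical bytes
#     assert torch_key() == key_bytes_from_parent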


@torch_key_cache
def torch_key() -> bytes:
    """
    Compute a key that captures the relevant state of the torch source files.
    """
    with dynamo_timed("inductor_codecache_torch_key", log_pt2_compile_event=False):
        if not config.is_fbcode():

            def get_code_hash(root: str) -> bytes:
                # Hash all Python modules under the root, plus a few extra
                # non-Python files that affect AOTInductor codegen.
                extra_files = (
                    "codegen/aoti_runtime/interface.cpp",
                    "script.ld",
                )
                inductor_root = os.path.dirname(__file__)
                extra_files = [os.path.join(inductor_root, x) for x in extra_files]
                hasher = hashlib.sha256()
                hasher.update(torch.__version__.encode("utf-8"))
                build_code_hash([root], "", hasher)
                for path in extra_files:
                    if os.path.exists(path):
                        with open(path, "rb") as f:
                            hasher.update(f.read())
                return hasher.digest()

            return get_code_hash(_TORCH_PATH)

        # In fbcode, the source hash is precomputed and shipped with the build.
        from libfb.py import parutil

        return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii")


def get_inductor_root() -> str:
    return os.path.dirname(__file__)


@dataclasses.dataclass
class OrderedSetHolder:
    """
    See FxGraphHashDetails. Holds a sorted list to support stable hashing
    of set kwargs.
    """

    items: list[Any]


class BypassFxGraphCache(Exception):
    """
    Exception to indicate that the FxGraphCache should be bypassed.
    """


class FxGraphHashDetails:
    """
    Object to capture all the details for a compiled FX graph relevant to computing
    a safe and stable cache key.
    """

    # Excluded kwargs that are not stable between runs.
    EXCLUDED_KWARGS = ["graph_id"]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
        fx_kwargs: _CompileFxKwargs,
        inputs_to_check: Sequence[int],
    ) -> None:
        self.gm = gm
        self.example_inputs = example_inputs
        self.cache_key_tag = cconfig.cache_key_tag

        # Order the kwargs so the hash is stable, and drop the kwargs that
        # are excluded from hashing.
        self.fx_kwargs: dict[str, object] = {}
        for k, v in sorted(fx_kwargs.items()):
            if k not in self.EXCLUDED_KWARGS:
                if type(v) in (set, OrderedSet):
                    # Sets are unordered, so hold the elements as a sorted
                    # list for stable hashing.
                    self.fx_kwargs[k] = OrderedSetHolder(sorted(v))
                else:
                    self.fx_kwargs[k] = v

        from torch._higher_order_ops.triton_kernel_wrap import (
            kernel_side_table,
            triton_kernel_wrapper_functional,
            triton_kernel_wrapper_mutation,
        )
        from torch._inductor.codegen.wrapper import (
            user_defined_triton_kernel_transitive_closure_source_code,
        )

        # The source of any user-defined Triton kernels is not captured by
        # pickling the GraphModule (the generated code only stores indices
        # into a dynamo-time side table), so collect it explicitly.
        self.user_defined_triton_source: list[Any] = []
        if gm is not None:
            for module in gm.modules():
                if not isinstance(module, torch.fx.GraphModule):
                    continue
                for node in itertools.chain(
                    module.graph.find_nodes(
                        op="call_function", target=triton_kernel_wrapper_functional
                    ),
                    module.graph.find_nodes(
                        op="call_function", target=triton_kernel_wrapper_mutation
                    ),
                ):
                    from triton.runtime.autotuner import Autotuner

                    kernel = kernel_side_table.get_kernel(node.kwargs["kernel_idx"])
                    configs = None
                    if isinstance(kernel, Autotuner):
                        if kernel.configs:
                            configs = str(
                                sorted(
                                    sorted(str(kv) for kv in c.all_kwargs().items())
                                    for c in kernel.configs
                                )
                            )
                        kernel = kernel.fn

                    kernel_source = (
                        user_defined_triton_kernel_transitive_closure_source_code(
                            kernel
                        )
                    )
                    constant_args = kernel_side_table.get_constant_args(
                        node.kwargs["constant_args_idx"]
                    )
                    self.user_defined_triton_source.append(
                        (kernel_source, constant_args, configs)
                    )

        # Alignment checks for the inputs.
        self.inputs_to_check = inputs_to_check

        no_tensor_inputs = not any(isinstance(x, torch.Tensor) for x in example_inputs)
        # With no tensor inputs, the target device is not captured by the
        # example inputs, so include the current accelerator index in the key.
        if no_tensor_inputs and torch.accelerator.is_available():
            self.default_cuda_device_index = torch.accelerator.current_device_index()

        # Global determinism settings can affect the generated code.
        self.deterministic_algorithms_settings = (
            torch.are_deterministic_algorithms_enabled(),
            torch.is_deterministic_algorithms_warn_only_enabled(),
            torch.utils.deterministic.fill_uninitialized_memory,
        )

        # CUDA matmul precision and reduced-precision reduction settings
        # affect autotuning and the generated code.
        self.cuda_matmul_settings = (
            torch.backends.cuda.matmul.fp32_precision,
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction,
        )

        # Also hash on various system info (including the Triton compiler
        # version) and the inductor configuration.
        self.torch_version = torch_key()
        self.system_info = CacheBase.get_system()
        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
        # Custom post-grad and joint-graph passes contribute their uuid().
        self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
            config.post_grad_custom_pre_pass
        )
        self.precompile_enabled = torch._functorch.config.bundled_autograd_cache
        self.post_grad_custom_post_pass = self._get_custom_pass_detail(
            config.post_grad_custom_post_pass
        )
        self.joint_custom_pre_pass = self._get_custom_pass_detail(
            config.joint_custom_pre_pass
        )
        self.joint_custom_post_pass = self._get_custom_pass_detail(
            config.joint_custom_post_pass
        )
        self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
            config._pre_fusion_custom_pass
        )
        self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
            config._fuse_ddp_communication_passes
        )

        # Register out-of-tree backends so that their custom passes and
        # codegen configs are reflected in the key.
        init_backend_registration()
        self.custom_backend_passes = tuple(
            map(self._get_custom_pass_detail, custom_backend_passes.values())
        )

        # Out-of-tree backends may also register custom codegen configs.
        self.custom_backend_codegen_configs = {
            device: custom_config.save_config_portable(ignore_private_configs=False)
            for device, custom_config in custom_backend_codegen_configs.items()
            if custom_config is not None
        }

        # Custom partitioner function, if any.
        self._custom_partitioner_fn = self._get_custom_partitioner_fn_detail(
            config.custom_partitioner_fn
        )

    # This "unsafe" variant exists for configs (currently
    # _pre_fusion_custom_pass and _fuse_ddp_communication_passes) whose types
    # are too loose to hash safely: they may hold lists, strings,
    # CustomGraphPass instances, or bare callables.
    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
        if not custom_pass:
            return None
        if isinstance(custom_pass, list):
            return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
        if isinstance(custom_pass, str):
            return custom_pass
        if isinstance(custom_pass, CustomGraphPass):
            return custom_pass.uuid()
        if callable(custom_pass):
            # Returning None here is safe only because _check_can_cache
            # bypasses caching when these configs hold bare callables.
            return None
        raise AssertionError(f"unknown config type: {type(custom_pass)}")

    def _get_custom_pass_detail(
        self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
    ) -> Optional[Any]:
        if not custom_pass:
            return None
        assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
        return custom_pass.uuid()

    def _get_custom_partitioner_fn_detail(
        self, custom_partitioner_fn: CustomPartitionerFnType
    ) -> Optional[Any]:
        if not custom_partitioner_fn:
            return None
        assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
        return custom_partitioner_fn.uuid()


def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[InputType],
    fx_kwargs: _CompileFxKwargs,
    inputs_to_check: Sequence[int],
) -> tuple[str, list[str]]:
    """
    Generate a unique hash of the FX graph for caching.
    """
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
    has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
    pickler = FxGraphCachePickler(gm, has_user_defined_triton_kernels)

    # The "f" prefix distinguishes FX graph cache keys from the other kinds
    # of hashes this module produces.
    key = "f" + pickler.get_hash(details)
    debug_lines = pickler.debug_lines(details)
    debug_str = "\n".join(debug_lines)
    log.debug("FX graph cache hash details for key %s:\n%s", key, debug_str)
    return key, debug_lines


def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
    """
    Ephemerally increase the NCCL timeout when compiling for a distributed job.
    Returns the number of seconds the timeout was increased by.
    """
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return 0

    increased_timeout_sec = int(time_saved_ns // 1e9)

    if config.is_fbcode():
        fudge_factor = torch._utils_internal.justknobs_getval_int(
            "pytorch/remote_cache:ephemeral_timeout_fudge_factor_percentage"
        )
        log.info(
            "Ephemeral NCCL timeout increase fudge factor %d and original increase value %d",
            fudge_factor,
            increased_timeout_sec,
        )
        increased_timeout_sec += int(increased_timeout_sec * fudge_factor / 100)

    log.info("Increasing NCCL timeout by %d", increased_timeout_sec)
    dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
        timedelta(seconds=increased_timeout_sec)
    )
    return increased_timeout_sec
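

# Worked example (the fudge factor is illustrative): a cache hit that saved
# time_saved_ns = 120e9 (120s) of compile time yields
# increased_timeout_sec = 120; with a 25% fudge factor in fbcode this becomes
# 120 + int(120 * 25 / 100) = 150 seconds added to every process group's
# NCCL timeout.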


class GuardedCache(Generic[T]):
    """
    Mixin for caches that have guards associated with their entries.
    """

    @classmethod
    def _get_tmp_dir_for_key(cls: type[GuardedCache[T]], _key: str) -> str:
        raise NotImplementedError("Implement _get_tmp_dir_for_key on parent class")

    @classmethod
    def iterate_over_candidates(
        cls: type[GuardedCache[T]],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        key: str,
    ) -> Generator[tuple[T, bytes], None, None]:
        if local:
            subdir = cls._get_tmp_dir_for_key(key)
            if os.path.exists(subdir):
                for path in sorted(os.listdir(subdir)):
                    try:
                        with open(os.path.join(subdir, path), "rb") as f:
                            content = f.read()
                            yield pickle.loads(content), content
                    except Exception:
                        log.warning(
                            "fx graph cache unable to load compiled graph",
                            exc_info=True,
                        )

        if remote_cache:
            try:
                if (cache_data := remote_cache.get(key)) is not None:
                    assert isinstance(cache_data, dict)
                    data = cache_data["data"]
                    assert isinstance(data, (str, bytes))
                    content = base64.b64decode(data)
                    yield pickle.loads(content), content
            except Exception:
                log.warning(
                    "%s unable to load compiled graph", cls.__name__, exc_info=True
                )

    @classmethod
    def find_guarded_entry(
        cls: type[GuardedCache[T]],
        key: str,
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
        hints: list[int],
    ) -> tuple[Optional[T], Optional[bytes], dict[str, str]]:
        """
        Find the first cache entry in iterate_over_candidates that passes `evaluate_guards`.

        Args:
            key: The cache key to look up
            local: Whether to check the local cache
            remote_cache: The remote cache to check, if any
            evaluate_guards: Function that evaluates whether a guard passes the
                check, given the guard expression and a list of hint values.
            hints: List of symint hints paired with evaluate_guards

        Returns:
            A tuple (graph, pickled_content, info): the entry and its pickled
            bytes (both None on a miss), plus a dict of cache-status details.
        """
        graph = None
        pickled_content = None
        result_status = "full_miss"
        sample_guards_expr = None

        # Iterate over the candidate entries for this key and take the first
        # one whose guards pass in the current context.
        for candidate, content in cls.iterate_over_candidates(local, remote_cache, key):
            assert hasattr(candidate, "guards_expr")
            if not candidate.guards_expr:
                # No guards to evaluate: any entry matches.
                graph = candidate
                pickled_content = content
                result_status = "hit"
                break

            # Evaluate the guard expression against the current hints.
            hit = bool(evaluate_guards(candidate.guards_expr, hints))
            if hit:
                graph = candidate
                pickled_content = content
                result_status = "hit"
                sample_guards_expr = candidate.guards_expr
                break
            else:
                # Keep looking, but remember that at least one entry missed
                # only because of its guards.
                result_status = "guard_miss"
                sample_guards_expr = candidate.guards_expr

        info = {"cache_status_detailed": result_status}
        if sample_guards_expr is not None:
            info["cache_status_guard_expr"] = sample_guards_expr
        return graph, pickled_content, info

    @classmethod
    def _filter_backed_symints(
        cls: type[GuardedCache[T]], inputs: Sequence[InputType]
    ) -> list[torch.SymInt]:
        """
        Get the backed SymInt objects from the input list. Note that we can never
        have guards that depend on an unbacked symint.
        """
        return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]

    @classmethod
    def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]:
        """
        Helper to get the shape env from the tracing context.
        """
        ctx = torch._guards.TracingContext.try_get()
        if not ctx or not ctx.fake_mode:
            return None
        return ctx.fake_mode.shape_env


@CacheArtifactFactory.register
class InductorCacheArtifact(CacheArtifact):
    @override
    def populate_cache(self) -> None:
        FxGraphCache._write_to_local_cache(self.key, self.content)

    @override
    @staticmethod
    def type() -> str:
        return "inductor"


class FxGraphCache(GuardedCache[CompiledFxGraph]):
    """
    Supports caching and reusing compiled Fx graphs.

    The overall strategy is as follows:
    - This cache stores entries on disk. When saving an entry, we can't
      serialize callables (that could be C++, Triton, etc.), so we serialize
      their own disk cache location. We then recreate the compiled artifact
      after fetching from disk.
    - For indexing the cache, we gather the fields relevant to identifying an
      FxGraph (the graph module, graph inputs, system settings etc.) into an
      FxGraphCacheDetails object, pickle it, and compute a hash for the key.
      See FxGraphCachePickler.
    - Among the metadata we store, we also include a guards expression that's
      appropriate for validating any symbols for Tensor arguments that have
      symbolic bounds. On cache lookup then, we evaluate those guards in the
      current context to validate that a cached entry can be served.
    - A given graph could have multiple compiled versions, corresponding to
      different sets of guards. Therefore, we store cache entries in the form:
          <temp dir>/<fx graph hash>/<serialized metadata>
    - On lookup, we compute the key from the graph details, iterate over all
      leaf files in the corresponding subdirectory, deserialize the entry, and
      evaluate its guards expression. If the evaluation succeeds, we have a
      cache hit. If it fails, we compile the graph and store a new entry.
    - Finally, on a cache hit, we need to make sure any guards that would
      have been created during compilation are added to the current context.
    """

    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "fxgraph")

    @classmethod
    def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str:
        """
        Return the disk location for a given cache key.
        """
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
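
    # Layout sketch: for a key "f7abc...", entries live at
    #
    #     <cache_dir>/fxgraph/7a/f7abc.../<sha256 of the pickled entry>
    #
    # so a single key can hold multiple entries that differ only in their guards.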

    @staticmethod
    def cache_hit_post_compile(
        graph: CompiledFxGraph,
        cache_info: dict[str, Any],
        constants: CompiledFxGraphConstants,
    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Cache-specific post-compile steps that need to run when we find a graph
        in the cache. This includes putting bundled triton artifacts in the
        right place, reloading the PyCodeCache artifact, etc.

        These don't always happen (e.g., not on a cache miss), so they live in
        a separate function from CompiledFxGraph.post_compile.
        """
        if bundle := graph._triton_bundle:
            triton_bundler_meta = TritonBundler.read_and_emit(bundle)
            if (meta := triton_bundler_meta) is not None:
                cache_info["triton_bundler_meta"] = str(meta)
                CompileEventLogger.try_add_pt2_compile(
                    "inductor_compile", cached_kernel_names=meta.cached_kernel_names
                )
                CompileEventLogger.try_add_pt2_compile(
                    "AOTAutogradCache.inductor_load",
                    cached_kernel_names=meta.cached_kernel_names,
                )
                if len(meta.cached_kernel_names) > 0:
                    CompileEventLogger.try_(
                        CompileEventLogger.increment_toplevel, "num_triton_bundles"
                    )

        try:
            artifact_path = graph.after_deserialization(constants)

            from .graph import GraphLowering

            # Save the output code via the debug hook, if one is registered.
            if GraphLowering.save_output_code is not None:
                GraphLowering.save_output_code(graph.source_code)

        except OSError:
            # The artifact may have been evicted from disk in the interim;
            # treat this as a cache miss.
            return None, cache_info

        inductor_meta = autotune_cache.inductor_meta_from_config()
        code = graph.source_code
        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)

        # Increment the cached metrics and counters by the amounts recorded
        # when the FX graph was originally compiled.
        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
        counters["inductor"] += graph.counter_deltas

        output_code_log.debug("Output code: \n%s", code)
        output_code_log.debug("Output code written to: %s", artifact_path)
        # Log the output code and provenance artifacts to the structured trace.
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_graph_runnable",
                "encoding": "string",
            },
            payload_fn=lambda: graph.runnable_graph_str,
        )
        trace_structured(
            "inductor_post_grad_graph",
            payload_fn=lambda: graph.inductor_post_grad_graph_str,
        )
        trace_structured(
            "inductor_output_code",
            lambda: {"filename": artifact_path},
            payload_fn=lambda: code,
        )
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "inductor_provenance_tracking_node_mappings",
                "encoding": "json",
            },
            payload_fn=lambda: graph.inductor_provenance_mapping_str,
        )
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "inductor_provenance_tracking_kernel_stack_traces",
                "encoding": "json",
            },
            payload_fn=lambda: graph.inductor_provenance_stack_traces_str,
        )
        return graph, cache_info

    @staticmethod
    def _lookup_graph(
        key: str,
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        constants: CompiledFxGraphConstants,
        evaluate_guards: Optional[
            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
        ] = None,
    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Look up a compiled graph in the cache by key. On a hit, return the
        deserialized CompiledFxGraph object; on a miss, return None.
        `constants` tracks a list of constants, or a way to obtain the list of
        constants associated with a given cache entry.
        `evaluate_guards` allows AOTAutogradCache and other callers to customize
        what constitutes a guard success. Normally, a guard hit happens if
        `shape_env.evaluate_guards_expression` returns True.
        """
        shape_env = FxGraphCache._get_shape_env()
        assert shape_env is not None

        symints = FxGraphCache._filter_backed_symints(example_inputs)
        hints = [hint_int(s) for s in symints]

        if config.unsafe_skip_cache_dynamic_shape_guards:
            # Unsafely treat every candidate entry as a guard hit.
            evaluate_guards = lambda x, y: True

        if evaluate_guards is None:
            evaluate_guards = shape_env.evaluate_guards_expression

        cache_info: dict[str, Any] = dict()

        # Find the first entry (if any) whose guards pass.
        graph, pickled_content, guard_info = FxGraphCache.find_guarded_entry(
            key, local, remote_cache, evaluate_guards, hints
        )
        cache_info.update(guard_info)
        if graph is None:
            return None, cache_info

        if pickled_content is not None:
            CacheArtifactManager.record_artifact(
                InductorCacheArtifact.type(), key, pickled_content
            )

        # Now re-evaluate with the symints to add any guards to the current env.
        if graph.guards_expr:
            check = bool(evaluate_guards(graph.guards_expr, symints))
            assert check is True
            log.debug(
                "fx graph cache key %s post-load guards: %s", key, shape_env.guards
            )

        return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)

    @staticmethod
    def _write_to_local_cache(key: str, content: bytes) -> None:
        subdir = FxGraphCache._get_tmp_dir_for_key(key)
        if not os.path.exists(subdir):
            os.makedirs(subdir, exist_ok=True)

        # Use a hash of the serialized entry as the file name: several
        # entries (differing in their guards) can coexist under one key.
        path = os.path.join(subdir, sha256_hash(content))
        write_atomic(path, content, make_dirs=True)

    @staticmethod
    def _save_graph(
        key: str,
        compiled_graph: OutputCode,
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
    ) -> None:
        """
        Store a serialized CompiledFxGraph on disk.
        """
        from .compile_fx import CompiledFxGraph

        assert isinstance(compiled_graph, CompiledFxGraph), (
            f"serialization for {type(compiled_graph)} NYI"
        )

        # Before serializing, compute the guard expression that will be used
        # to validate the entry on load. It's sufficient to consider only the
        # SymInt args to the graph: Tensor shapes are already captured in the
        # cache key, and any Tensor arg with a symbolic shape contributes a
        # SymInt arg to the graph.
        shape_env = FxGraphCache._get_shape_env()
        assert shape_env is not None
        symints = FxGraphCache._filter_backed_symints(example_inputs)
        guards = shape_env.get_pruned_guards(symints)
        compiled_graph.guards_expr = shape_env.produce_guards_expression(
            placeholders=symints, guards=guards
        )
        disk_compiled_graph = copy(compiled_graph)
        disk_compiled_graph.prepare_for_serialization()

        try:
            content = pickle.dumps(disk_compiled_graph)
        except Exception:
            log.warning(
                "fx graph cache unable to serialize compiled graph", exc_info=True
            )
            counters["inductor"]["fxgraph_cache_pickle_error"] += 1
            return

        try:
            CacheArtifactManager.record_artifact(
                InductorCacheArtifact.type(), key, content
            )
            if local:
                FxGraphCache._write_to_local_cache(key, content)

            if remote_cache:
                time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6)
                cache_data: JsonDataTy = {
                    "data": base64.b64encode(content).decode("ascii"),
                    "time_taken_ms": time_taken_ms,
                }
                remote_cache.put(key, cache_data)
        except Exception:
            log.warning("fx graph unable to write to cache", exc_info=True)
            counters["inductor"]["fxgraph_cache_write_error"] += 1

    @staticmethod
    def _check_for_hop(gm: torch.fx.GraphModule) -> None:
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if (
                    isinstance(node.target, torch._ops.HigherOrderOperator)
                    and not node.target.cacheable()
                ):
                    raise BypassFxGraphCache(
                        f"Can't cache HigherOrderOperator: {node.target.name()}"
                    )
                if node.op == "get_attr" and isinstance(
                    getattr(gm, node.target), torch._C.ScriptObject
                ):
                    raise BypassFxGraphCache("Can't cache torchbind objects")

    @staticmethod
    def _check_can_cache(gm: torch.fx.GraphModule) -> None:
        """
        Check some conditions that would preclude caching and raise
        BypassFxGraphCache to bypass in case caching is not possible.
        """
        # Custom passes can only be cached if they implement CustomGraphPass
        # and provide a uuid().
        for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
            if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                raise BypassFxGraphCache("Unsupported post grad custom pass")

        for p in (config.joint_custom_pre_pass, config.joint_custom_post_pass):
            if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                raise BypassFxGraphCache("Unsupported joint custom pass")

        # See FxGraphHashDetails._get_custom_pass_detail_unsafe for how these
        # loosely typed configs are hashed.
        if config._pre_fusion_custom_pass is not None:
            if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
                raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
        for p in config._fuse_ddp_communication_passes:
            if callable(p) and not isinstance(p, CustomGraphPass):
                raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")

        # Freezing can embed constants that aren't static across runs.
        if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
            "pytorch/inductor:allow_freezing_with_caching"
        ):
            raise BypassFxGraphCache("Skipping graph with frozen constants")

        if config.aot_inductor.use_runtime_constant_folding:
            raise BypassFxGraphCache(
                "Runtime constant folding can introduce constants that aren't "
                "static across runs"
            )

        from torch._inductor.compiler_bisector import CompilerBisector

        if CompilerBisector.bisection_enabled:
            log.debug("dont cache graph when bisect enabled")
            raise BypassFxGraphCache

        # The treatment of guards in the caching implementation requires
        # that we have a shape env.
        if FxGraphCache._get_shape_env() is None:
            log.debug("fx graph cache no shape env")
            raise BypassFxGraphCache("No shape env")

        # Bypass for any higher-order op that declares itself uncacheable,
        # and for torchbind objects.
        FxGraphCache._check_for_hop(gm)

    @staticmethod
    def prepare_key(
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[InputType],
        fx_kwargs: _CompileFxKwargs,
        inputs_to_check: Sequence[int],
        remote: bool,
    ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
        """
        Check that the inductor input is cacheable, then compute and return
        the cache key for the input.
        Returns (key_info, cache_info) where:
        - key_info is (hash_key, debug_lines), and
        - cache_info will contain debug info in the event of BypassFxGraphCache.

        NB: It is possible to have this function return a union instead. But
        I personally believe it is more annoying/difficult to read in that format.
        """
        try:
            FxGraphCache._check_can_cache(gm)
            key, debug_lines = compiled_fx_graph_hash(
                gm, example_inputs, fx_kwargs, inputs_to_check
            )
        except BypassFxGraphCache as e:
            counters["inductor"]["fxgraph_cache_bypass"] += 1
            log.info("Bypassing FX Graph Cache because '%s'", e)
            if remote:
                log_cache_bypass("bypass_fx_graph", str(e))
            cache_info = {
                "cache_state": "bypass",
                "cache_bypass_reason": str(e),
                "cache_event_time": time_ns(),
            }
            return None, cache_info

        return (key, debug_lines), {}

    @staticmethod
    def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
        """
        Attempt to load the remote cache; returns None on error.
        """
        cache_id = "fx-graph-v1"
        return create_cache(
            cache_id,
            config.is_fbcode(),
            "FbRemoteFxGraphCache",
            "RemoteFxGraphCache",
        )

    @staticmethod
    def load_with_key(
        key: str,
        debug_lines: list[str],
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        is_backward: bool,
        constants: CompiledFxGraphConstants,
        evaluate_guards: Optional[
            Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
        ] = None,
    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Look up the graph with the given key and return the result plus metadata.
        Doesn't do any logging on its own, because AOTAutograd handles a cache
        miss differently from FXGraphCache.
        """
        compiled_graph, cache_info = FxGraphCache._lookup_graph(
            key, example_inputs, local, remote_cache, constants, evaluate_guards
        )
        cache_info = {
            **cache_info,
            "key": key,
            "components": debug_lines,
            "cache_event_time": time_ns(),
        }
        if compiled_graph is not None:
            log.info("fx graph cache hit for key %s", key)
            counters["inductor"]["fxgraph_cache_hit"] += 1
            cache_info["cache_state"] = "hit"
            if remote_cache:
                # Count remote cache hits.
                CompileEventLogger.try_(
                    CompileEventLogger.increment_toplevel,
                    "inductor_fx_remote_cache_hit_count",
                )
                CompileEventLogger.try_(
                    CompileEventLogger.add_to_set_toplevel,
                    "inductor_fx_remote_cache_hit_keys",
                    key,
                )

            if (time_saved_ns := compiled_graph._time_taken_ns) is not None:
                cache_info["time_saved_ns"] = time_saved_ns
                CompileEventLogger.try_(
                    CompileEventLogger.increment_toplevel,
                    "distributed_ephemeral_timeout_us",
                    time_saved_ns // 1000,
                )
                if (
                    ephemeral_increase
                    := add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
                ) != 0:
                    cache_info["ephemeral_timeout_increase"] = ephemeral_increase
        else:
            if remote_cache:
                # Count remote cache misses.
                CompileEventLogger.try_(
                    CompileEventLogger.increment_toplevel,
                    "inductor_fx_remote_cache_miss_count",
                )
                CompileEventLogger.try_(
                    CompileEventLogger.add_to_set_toplevel,
                    "inductor_fx_remote_cache_miss_keys",
                    key,
                )
            log.info("fx graph cache miss for key %s", key)
            counters["inductor"]["fxgraph_cache_miss"] += 1
            cache_info["cache_state"] = "miss"

        return compiled_graph, cache_info

    @staticmethod
    def clear() -> None:
        """
        Clear out the on-disk cache.
        """
        try:
            shutil.rmtree(FxGraphCache._get_tmp_dir())
        except FileNotFoundError:
            pass


@functools.cache
def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
    """Return the (directory, filename) where AOT Inductor compiled kernels are stored."""

    def get_module_ext_type() -> str:
        if _IS_WINDOWS:
            return ".pyd"
        else:
            return ".so"

    if path.endswith((get_module_ext_type(), ".pt2")):
        return os.path.split(path)
    return path, ""
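

# Usage sketch (POSIX paths; on Windows the module extension is ".pyd"):
#
#     assert split_aot_inductor_output_path("/tmp/model.so") == ("/tmp", "model.so")
#     assert split_aot_inductor_output_path("/tmp/out") == ("/tmp/out", "")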


@clear_on_fresh_cache
class CudaKernelParamCache:
    cache: dict[str, dict[str, Any]] = {}
    cache_clear = staticmethod(cache.clear)

    @classmethod
    def set(
        cls,
        key: str,
        params: dict[str, Optional[str]],
        cubin: str,
        bin_type: str,
        asm: Optional[str] = None,
        asm_type: Optional[str] = None,
    ) -> None:
        basename = None
        if config.aot_inductor.package_cpp_only:
            assert config.triton.unique_kernel_names, (
                "package_cpp_only requires triton kernel names to be unique"
            )
            assert params["mangled_name"], "Missing kernel name"
            basename = params["mangled_name"]

        _, bin_path = write(
            cubin,
            bin_type,
            hash_type=bin_type,
            specified_dir=split_aot_inductor_output_path(
                config.aot_inductor.output_path
            )[0],
            key=basename,
        )
        # Recover the basename actually used on disk (when `key` was None
        # above, it is derived from the content hash).
        basename, _ = get_name_and_dir_from_output_file_path(bin_path)

        if config.aot_inductor.emit_multi_arch_kernel:
            bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"}
            assert bin_type in bin_type_to_ext, (
                "multi_arch_kernel_binary is only supported on CUDA/XPU"
            )
            base_path, _ = os.path.splitext(bin_path)
            bin_path = base_path + bin_type_to_ext[bin_type]

        asm_path: str = ""
        if (
            config.aot_inductor.emit_multi_arch_kernel
            or config.aot_inductor.package_cpp_only
        ):
            assert asm, "Missing kernel assembly code"
            assert asm_type, "Missing kernel assembly type"
            _, asm_path = write(
                asm,
                asm_type,
                hash_type=asm_type,
                specified_dir=split_aot_inductor_output_path(
                    config.aot_inductor.output_path
                )[0],
                # Give the asm file the same basename as the kernel binary.
                key=basename,
            )

        params[get_cpp_wrapper_cubin_path_name()] = bin_path
        params["asm"] = asm_path
        cls.cache[key] = params

    @classmethod
    def get(cls, key: str) -> Optional[dict[str, Any]]:
        return cls.cache.get(key, None)

    @classmethod
    def get_keys(cls) -> KeysView[str]:
        return cls.cache.keys()


class AotCodeCompiler:
    """
    Compile AOT Inductor generated code.
    """

    @classmethod
    def compile(
        cls,
        graph: GraphLowering,
        wrapper_code: str,
        kernel_code: str,
        serialized_extern_kernel_nodes: Optional[str],
        *,
        device_type: str,
        additional_files: list[str],
    ) -> Union[list[Union[str, Weights]], str]:
        """
        Returns the .so path, or a list of generated files if
        config.aot_inductor.package=True.
        """
        generated_files: list[Union[str, Weights]] = additional_files

        _set_gpu_runtime_env()

        picked_vec_isa = pick_vec_isa()
        vec_isa_cmd_gen = CppBuilder(
            name="o",
            sources="i",
            BuildOption=CppTorchDeviceOptions(
                vec_isa=picked_vec_isa,
                device_type=device_type,
                aot_mode=graph.aot_mode,
            ),
        )
        # Only the compile command line is needed here (it is hashed into the
        # write() keys below); nothing is actually compiled.
        cpp_command = repr(vec_isa_cmd_gen.get_command_line())

        # fbcode CPU AOT builds compile with paths relative to the build dir.
        use_relative_path = (
            config.is_fbcode() and device_type == "cpu" and graph.aot_mode
        )

        (
            specified_output_path,
            specified_artifact_name,
        ) = split_aot_inductor_output_path(config.aot_inductor.output_path)

        # With package_cpp_only, ship a single source file: fold the kernel
        # code into the wrapper.
        if config.aot_inductor.package_cpp_only:
            wrapper_code = "\n".join((wrapper_code, kernel_code))
            kernel_code = ""

        wrapper_key, wrapper_path = write(
            wrapper_code,
            "wrapper.cpp",
            extra=cpp_command,
            specified_dir=specified_output_path,
            key=config.aot_inductor.model_name_for_generated_files,
        )
        kernel_code = (
            f"// Triton kernels are embedded as comments in {wrapper_path}\n"
            + kernel_code
        )
        _, kernel_path = write(
            kernel_code,
            "kernel.cpp",
            extra=cpp_command,
            specified_dir=specified_output_path,
            key=config.aot_inductor.model_name_for_generated_files,
        )

        header_code = ""
        header_path = ""
        if config.aot_inductor.compile_standalone:
            # Use csrc/inductor/aoti_runtime/model.h as the template for the
            # standalone model header.
            with open(
                os.path.join(
                    os.path.dirname(os.path.dirname(__file__)),
                    "csrc",
                    "inductor",
                    "aoti_runtime",
                    "model.h",
                )
            ) as f:
                model_class_name = config.aot_inductor.model_name_for_generated_files
                class_name = f"AOTInductorModel{model_class_name}"
                header_code = f.read()

            # Rename the AOTInductorModel class to a model-specific name.
            header_code = (
                header_code.replace("<AOTInductorModel>", f"<{class_name}>")
                .replace("AOTInductorModel(", f"{class_name}(")
                .replace("AOTInductorModel :", f"{class_name} :")
            )
            _, header_path = write(
                header_code,
                "h",
                specified_dir=specified_output_path,
                key=model_class_name,
            )
| | # Dump the combined wrapper + kernel source for debug output.
| | with WritableTempFile("w+") as t: |
| | """ |
| | Avoid "Permission denied error" on Windows: |
| | with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: |
| | # Not writable on Windows: |
| | # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile |
| | |
| | Example: |
| | with WritableTempFile("w", suffix=".gv") as temp_file: |
| | tree.to_dotfile(temp_file.name) |
| | """ |
| | t.writelines((wrapper_code, "\n", kernel_code, "\n")) |
| | t.flush() |
| | V.debug.output_code(t.name, extension="cpp") |
| |
|
| | if config.aot_inductor.package: |
| | generated_files.append(wrapper_path) |
| | if not config.aot_inductor.package_cpp_only: |
| | generated_files.append(kernel_path) |
| | if config.aot_inductor.compile_standalone: |
| | generated_files.append(header_path) |
| |
|
| | output_code_log.info("Wrapper code written to: %s", wrapper_path) |
| | output_code_log.info("Kernel code written to: %s", kernel_path) |
| | trace_structured( |
| | "graph_dump", |
| | lambda: { |
| | "name": "inductor_aot_wrapper_code", |
| | "type": "cpp", |
| | "filename": wrapper_path, |
| | }, |
| | payload_fn=lambda: wrapper_code, |
| | ) |
| | trace_structured( |
| | "graph_dump", |
| | lambda: { |
| | "name": "inductor_aot_kernel_code", |
| | "type": "cpp", |
| | "filename": kernel_path, |
| | }, |
| | payload_fn=lambda: kernel_code, |
| | ) |
| | if config.aot_inductor.compile_standalone: |
| | output_code_log.info("Header code written to: %s", header_path) |
| | trace_structured( |
| | "graph_dump", |
| | lambda: { |
| | "name": "inductor_aot_header_code", |
| | "type": "cpp", |
| | "filename": header_path, |
| | }, |
| | payload_fn=lambda: header_code, |
| | ) |
| |
|
| | # Stage build intermediates in a subdirectory named after the wrapper
| | # hash so concurrent builds of different models don't collide.
| | wrapper_path_operator = Path(wrapper_path) |
| | kernel_path_operator = Path(kernel_path) |
| | specified_sub_dir = wrapper_path_operator.parent / wrapper_key |
| | if not specified_sub_dir.exists(): |
| | specified_sub_dir.mkdir(exist_ok=True) |
| | cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt") |
| |
|
| | def _compile_consts(consts: bytes, platform: str) -> str: |
| | # Prefer the assembly path for embedding weights; the platform/device checks below may disable it.
| | use_asm_build: bool = config.aot_inductor.use_consts_asm_build |
| |
|
| | if platform == "linux": |
| | if graph.mutated_buffers & OrderedSet(graph.constants.keys()): |
| | # Mutated constants must live in a writable section. Use the "large"
| | # section variants (.ldata here, .lrodata below) so multi-GB blobs
| | # don't overflow 32-bit relocations against neighboring sections.
| | if len(consts) > 2_000_000_000: |
| | raise ValueError( |
| | "Models with buffer mutation included doesn't support constants greater than 2GB!" |
| | ) |
| | section_attr = '.ldata, "aw"' |
| | else: |
| | section_attr = '.lrodata, "a"' |
| | symbol_prefix = "" |
| | elif platform == "darwin": |
| | section_attr = "__DATA,__data" |
| | symbol_prefix = "_" |
| | elif platform == "win32": |
| | symbol_prefix = "" |
| | # ml64 caps alignment at 16 bytes, so Windows always takes the C++ path.
| | use_asm_build = False |
| | else: |
| | raise RuntimeError(f"Unsupported platform: {platform}") |
| |
|
| | # The Intel toolchain used for XPU fails to build the hand-written
| | # assembly file, so XPU also falls back to the C++ path.
| | if device_type == "xpu": |
| | use_asm_build = False |
| |
|
| | is_large_consts = len(consts) > 1024 |
| | is_zero_size_consts = len(consts) == 0 |
| |
|
| | def format_consts_to_gnu_asm( |
| | consts: bytes, |
| | align_bytes: int, |
| | symbol_prefix: str, |
| | is_large_consts: bool, |
| | ) -> tuple[str, str]: |
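| | # Illustrative shape of the emitted assembly (section name and
| | # alignment vary by platform/config; values below are examples only):
| | #     .section .lrodata, "a"
| | #     .balign 64
| | #     .globl _binary_constants_bin_start
| | # _binary_constants_bin_start:
| | #     .byte 18            (or a .quad placeholder patched post-build)
| | #     ...
| | #     .globl _binary_constants_bin_end
| | # _binary_constants_bin_end: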
| | consts_asm = f"\t.section\t{section_attr}\n" |
| | consts_asm += f"\t.balign {align_bytes}\n" |
| | consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" |
| | consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n" |
| | if not is_large_consts: |
| | for c in consts: |
| | consts_asm += f"\t.byte {c}\n" |
| | # Emit at least one byte even when there are no constants; otherwise
| | # the assembler won't materialize the section.
| | if not consts: |
| | consts_asm += "\t.space 1\n" |
| | else: |
| | consts_asm += "\t.quad 0x1234567899abcdef\n" |
| | consts_asm += f"\t.space {len(consts) - 8}\n" |
| | consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" |
| | consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" |
| | return consts_asm, "weights.S" |
| |
|
| | # C++ fallback: embed the weights as an initialized char array.
| | def format_consts_to_cpp( |
| | consts: bytes, align_bytes: int, symbol_prefix: str |
| | ) -> tuple[str, str]: |
| | consts_size = len(consts) |
| | asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\ |
| | #define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\ |
| | #else\t\n\ |
| | #define ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n\ |
| | #endif\t\n\ |
| | \t\n\ |
| | ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n""" |
| | const_cpp = asan_attr |
| | const_cpp += f"alignas({align_bytes}) extern " |
| | const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" |
| | count_bytes = 0 |
| | for c in consts: |
| | const_cpp += f"{c}, " |
| | count_bytes = count_bytes + 1 |
| | if count_bytes % 16 == 0: |
| | const_cpp += "\t\n" |
| | const_cpp += "};\t\n" |
| | const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" |
| | return const_cpp, "weights.cpp" |
| |
|
| | def get_zero_consts_asm_code( |
| | align_bytes: int, |
| | symbol_prefix: str, |
| | ) -> tuple[str, str]: |
| | """ |
| | This function handles zero-sized constants because the C++ standard prohibits zero-length arrays: |
| | https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c |
| | |
| | On Windows (MSVC): |
| | The compiler reports error C2466 for zero-sized arrays: |
| | https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466 |
| | Solution: Use assembly compilation to handle this case. |
| | |
| | Why not use Win32 assembly for all paths? |
| | ml64 only supports alignment up to 16 bytes, which isn't optimal for performance. |
| | |
| | Cross-platform implementation: |
| | Linux: pass '-pedantic' so the C++ compiler rejects zero-sized arrays
| | Windows: MSVC rejects zero-sized arrays by default
| | """ |
| | if _IS_WINDOWS: |
| | # MASM (ml64) source declaring empty start/end weight symbols.
| | asm_code = """ |
| | option casemap:none |
| | .data |
| | ?_binary_constants_bin_start@@3PAEA: |
| | align 16 |
| | ?_binary_constants_bin_end@@3PAEA: |
| | align 16 |
| | public ?_binary_constants_bin_start@@3PAEA |
| | public ?_binary_constants_bin_end@@3PAEA |
| | end |
| | """ |
| | asm_ext = "asm" |
| | else: |
| | asm_code = f"\t.section\t{section_attr}\n" |
| | asm_code += f"\t.balign {align_bytes}\n" |
| | asm_code += ( |
| | f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" |
| | ) |
| | asm_code += f"{symbol_prefix}_binary_constants_bin_start:\n" |
| | asm_code += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" |
| | asm_code += f"{symbol_prefix}_binary_constants_bin_end:\n" |
| | asm_ext = "S" |
| | return asm_code, asm_ext |
| |
|
| | if use_asm_build: |
| | consts_code, code_ext = format_consts_to_gnu_asm( |
| | consts, ALIGN_BYTES, symbol_prefix, is_large_consts |
| | ) |
| | else: |
| | if is_zero_size_consts: |
| | consts_code, code_ext = get_zero_consts_asm_code( |
| | ALIGN_BYTES, symbol_prefix |
| | ) |
| | else: |
| | consts_code, code_ext = format_consts_to_cpp( |
| | consts, ALIGN_BYTES, symbol_prefix |
| | ) |
| |
|
| | _, consts_s = write( |
| | consts_code, |
| | code_ext, |
| | specified_dir=str(specified_sub_dir), |
| | key=config.aot_inductor.model_name_for_generated_files, |
| | ) |
| | consts_s = Path(consts_s) |
| | object_build_options = CppTorchDeviceOptions( |
| | device_type=device_type, |
| | aot_mode=graph.aot_mode, |
| | compile_only=True, |
| | use_relative_path=use_relative_path, |
| | ) |
| | object_builder = CppBuilder( |
| | name=str(consts_s.stem), |
| | sources=str(consts_s), |
| | output_dir=str(consts_s.parent), |
| | BuildOption=object_build_options, |
| | ) |
| | consts_o = object_builder.get_target_file_path() |
| | if not use_asm_build and is_zero_size_consts:
| | run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent)) |
| | else: |
| | object_builder.build() |
| |
|
| | if is_large_consts and use_asm_build: |
| | with open(consts_o, "r+b") as f: |
| | f.seek(0) |
| | hdr = f.read(1024) |
| | # Locate the .quad placeholder magic; its byte order follows the host.
| | start_idx = ( |
| | hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") |
| | if sys.byteorder == "little" |
| | else hdr.find(b"\x12\x34\x56\x78\x99\xab\xcd\xef") |
| | ) |
| | assert start_idx != -1 |
| | f.seek(start_idx) |
| | pos = 0 |
| | while pos < len(consts): |
| | rc = f.write(consts[pos:]) |
| | pos += rc |
| |
|
| | # The generated source can be very large; drop it once the object exists.
| | os.remove(consts_s) |
| |
|
| | return consts_o |
| |
|
| | from torch.utils._filelock import FileLock |
| |
|
| | lock_dir = get_lock_dir() |
| | lock = FileLock( |
| | os.path.join(lock_dir, wrapper_key + ".lock"), timeout=LOCK_TIMEOUT |
| | ) |
| | with lock: |
| | if serialized_extern_kernel_nodes: |
| | extern_kernel_nodes_json = str( |
| | wrapper_path_operator.with_suffix(".json") |
| | ) |
| | with open(extern_kernel_nodes_json, "w") as f: |
| | f.write(serialized_extern_kernel_nodes) |
| |
|
| | if config.aot_inductor.package: |
| | generated_files.append(extern_kernel_nodes_json) |
| |
|
| | metadata = config.aot_inductor.metadata |
| | metadata["AOTI_DEVICE_KEY"] = device_type |
| |
|
| | # Save the user-provided metadata next to the wrapper and kernel files.
| | meta_json = str( |
| | wrapper_path_operator.with_name( |
| | f"{wrapper_path_operator.stem}_metadata.json" |
| | ) |
| | ) |
| | for k, v in config.aot_inductor.metadata.items(): |
| | assert isinstance(k, str) and isinstance(v, str), (
| | "Metadata must only contain strings" |
| | ) |
| |
|
| | with open(meta_json, "w") as f: |
| | f.write(json.dumps(config.aot_inductor.metadata)) |
| |
|
| | kernel_meta_json = str( |
| | kernel_path_operator.with_name( |
| | f"{kernel_path_operator.stem}_metadata.json" |
| | ) |
| | ) |
| | shutil.copy(meta_json, kernel_meta_json) |
| |
|
| | if config.aot_inductor.package: |
| | generated_files.append(meta_json) |
| | if not config.aot_inductor.package_cpp_only: |
| | generated_files.append(kernel_meta_json) |
| |
|
| | output_so = ( |
| | config.aot_inductor.output_path |
| | if specified_artifact_name |
| | else str(wrapper_path_operator.with_suffix(".so")) |
| | ) |
| | all_cuda = all( |
| | graph.get_original_value_of_constant(name).is_cuda |
| | for name in graph.constants.keys() |
| | if name not in graph.folded_constants |
| | ) |
| |
|
| | def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: |
| | def _pad_to_alignment(raw_bytes: bytes) -> bytes: |
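| | # Rounds len(raw_bytes) up to the next multiple of ALIGN_BYTES;
| | # e.g. a 100-byte blob pads to 128 bytes when ALIGN_BYTES is 64.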
| | padded_bytes = raw_bytes.ljust( |
| | (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, |
| | b"\x00", |
| | ) |
| | return padded_bytes |
| |
|
| | # Serialize by reading the raw bytes of the tensor's (CPU) storage via
| | # ctypes; mkldnn tensors expose their data through dedicated ops.
| | import ctypes |
| |
|
| | if t.numel() == 0: |
| | return b"" |
| |
|
| | if t.is_mkldnn: |
| | data_ptr = torch.ops.mkldnn.data_ptr(t) |
| | nbytes = torch.ops.mkldnn._nbytes(t) |
| | else: |
| | t_cpu = t.untyped_storage().cpu() |
| | data_ptr = t_cpu.data_ptr() |
| | nbytes = t_cpu.nbytes() |
| |
|
| | raw_array = ctypes.cast( |
| | data_ptr, |
| | ctypes.POINTER(ctypes.c_ubyte * nbytes), |
| | ) |
| | raw_bytes = bytes(raw_array.contents) |
| | return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) |
| |
|
| | if config.aot_inductor.package_constants_in_so: |
| | serialized_weights = b"".join( |
| | _to_bytes(graph.get_original_value_of_constant(name), all_cuda) |
| | for name in graph.constants.keys() |
| | if name not in graph.folded_constants |
| | ) |
| | else: |
| | serialized_weights = b"" |
| |
|
| | if config.aot_inductor.package_constants_on_disk: |
| | # Hand the weights to the packager to store on disk instead of in the .so.
| | weights_dict = Weights( |
| | { |
| | graph.allocated_constant_name[name]: ( |
| | graph.get_original_value_of_constant(name), |
| | TensorProperties(graph.constants[name]), |
| | ) |
| | for name in graph.constants.keys() |
| | if name not in graph.folded_constants |
| | } |
| | ) |
| | generated_files.append(weights_dict) |
| |
|
| | consts_size = len(serialized_weights) |
| |
|
| | # Sections can't comfortably hold >2GB of weights; mmap them instead.
| | use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000 |
| | if config.aot_inductor.force_mmap_weights: |
| | use_mmap_weights = True |
| |
|
| | compile_command: dict[str, Any] = { |
| | "aot_mode": graph.aot_mode, |
| | "device_type": device_type, |
| | "use_mmap_weights": use_mmap_weights, |
| | "use_relative_path": use_relative_path, |
| | "vec_isa": picked_vec_isa, |
| | } |
| | # When kernels are compiled separately, the wrapper can be built at
| | # reduced optimization to cut compile time.
| | wrapper_build_options = CppTorchDeviceOptions( |
| | compile_only=True, |
| | min_optimize=not config.aot_inductor.package_cpp_only, |
| | **compile_command, |
| | ) |
| | kernel_build_options = CppTorchDeviceOptions( |
| | compile_only=True, |
| | **compile_command, |
| | ) |
| |
|
| | # Optionally precompile the shared headers to speed up both builds.
| | if config.aot_inductor.precompile_headers and not _IS_WINDOWS: |
| | header_file = _get_cpp_wrapper_header( |
| | device_type, aot_mode=graph.aot_mode |
| | ) |
| | wrapper_build_options.precompiled_header = _precompile_header( |
| | header_file, |
| | cpp_command, |
| | min_optimize=not config.aot_inductor.package_cpp_only, |
| | **compile_command, |
| | ) |
| | if cpp_prefix := _get_cpp_prefix_header(device_type): |
| | kernel_build_options.precompiled_header = _precompile_header( |
| | cpp_prefix, |
| | cpp_command, |
| | **compile_command, |
| | ) |
| |
|
| | wrapper_builder = CppBuilder( |
| | name=str(wrapper_path_operator.stem), |
| | sources=wrapper_path, |
| | output_dir=str(wrapper_path_operator.parent), |
| | BuildOption=wrapper_build_options, |
| | ) |
| | wrapper_compile_cmd = wrapper_builder.get_command_line() |
| | wrapper_o = wrapper_builder.get_target_file_path() |
| |
|
| | kernel_builder = CppBuilder( |
| | name=str(kernel_path_operator.stem), |
| | sources=kernel_path, |
| | output_dir=str(wrapper_path_operator.parent), |
| | BuildOption=kernel_build_options, |
| | ) |
| | kernel_compile_cmd = kernel_builder.get_command_line() |
| | kernel_o = kernel_builder.get_target_file_path() |
| |
|
| | log.debug("aot wrapper compilation command: %s", wrapper_compile_cmd) |
| | log.debug("aot kernel compilation command: %s", kernel_compile_cmd) |
| | if config.aot_inductor.package_cpp_only: |
| | # Don't compile here; record the flags and CMake config needed to build later.
| | compile_flags = str( |
| | wrapper_path_operator.with_name( |
| | f"{wrapper_path_operator.stem}_compile_flags.json" |
| | ) |
| | ) |
| | wrapper_build_options.save_flags_to_json(compile_flags) |
| | generated_files.append(compile_flags) |
| | wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type) |
| | wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path) |
| | generated_files.append(cmake_path) |
| | else: |
| | try: |
| | wrapper_builder.build() |
| | except (exc.CppCompileError, SkipFrame) as e: |
| | if " is too big to optimize" in str(e): |
| | raise RuntimeError(
| | "The wrapper is too big to optimize at this level. Please set "
| | "torch._inductor.config.aot_inductor.compile_wrapper_opt_level = 'O0'."
| | ) from e
| | raise e |
| | kernel_builder.build() |
| |
|
| | if not use_mmap_weights: |
| | aot_constants = serialized_weights |
| | magic_number = 0 |
| | else: |
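| | # Embed only a 16-byte stub in the binary: two native-endian int64s
| | # holding the payload size (weights plus the trailing 8-byte magic)
| | # and a random magic number, re-checked against the copy appended
| | # to the .so after linking.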
| | magic_number = cast( |
| | int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() |
| | ) |
| | aot_constants = struct.pack("qq", consts_size + 8, magic_number) |
| |
|
| | consts_o = _compile_consts(aot_constants, sys.platform) |
| | custom_obj_idx = 0 |
| | # Serialize TorchBind (ScriptObject) constants into numbered side
| | # files; custom_objs_config.json (written below) maps each constant's
| | # qualified name to its file.
| |
|
| | qual_name_to_id = {} |
| | for custom_obj_idx, (name, constant) in enumerate( |
| | graph.torchbind_constants.items() |
| | ): |
| | if isinstance( |
| | constant, torch._library.fake_class_registry.FakeScriptObject |
| | ): |
| | constant = constant.real_obj |
| | assert isinstance(constant, torch._C.ScriptObject) |
| | custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}" |
| |
|
| | log.debug("saving script object %s as %s", name, custom_obj_name) |
| |
|
| | qual_name_to_id[name] = custom_obj_name |
| | custom_obj_bytes = torch._C._pickle_save(constant) |
| | custom_obj_path = os.path.join( |
| | wrapper_path_operator.parent, custom_obj_name |
| | ) |
| |
|
| | write_atomic(custom_obj_path, custom_obj_bytes, True) |
| | generated_files.append(custom_obj_path) |
| |
|
| | if qual_name_to_id: |
| | constants_config_json = os.path.join( |
| | wrapper_path_operator.parent, "custom_objs_config.json" |
| | ) |
| | with open(constants_config_json, "w") as f: |
| | f.write(json.dumps(qual_name_to_id)) |
| | generated_files.append(constants_config_json) |
| |
|
| | gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = ( |
| | ROCmCodeCache() if torch.version.hip else CUDACodeCache() |
| | ) |
| | gpu_kernels_o = gpu_codecache.aot_kernels_o.copy() |
| | # Reset so kernels from this compilation don't leak into the next one.
| | gpu_codecache.aot_kernels_o.clear() |
| |
|
| | if gpu_kernels_o: |
| | assert not config.aot_inductor.emit_multi_arch_kernel, ( |
| | "TODO: add emit_multi_arch_kernel support for cutlass kernels" |
| | ) |
| |
|
| | cubins_o = [] |
| | asm_files = [] |
| | if not _IS_WINDOWS: |
| | ld, objcopy = get_ld_and_objcopy(use_relative_path) |
| | kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) |
| | for kernel_name, value in CudaKernelParamCache.cache.items(): |
| | if kernel_name not in kernels: |
| | # The param cache may hold kernels from other compilations that are
| | # not part of this graph; skip them.
| | continue |
| |
|
| | if asm_file := value["asm"]: |
| | asm_files.append(asm_file) |
| |
|
| | cubin_file = value[get_cpp_wrapper_cubin_path_name()] |
| | if ( |
| | config.aot_inductor.emit_multi_arch_kernel |
| | and device_type == "cuda" |
| | ): |
| | current_arch = _nvcc_arch_as_compile_option() |
| | cmd = ( |
| | f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " |
| | # PTX for the current arch, JIT-compilable on future GPUs:
| | f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " |
| | # plus native SASS for the current arch (no JIT needed locally):
| | f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " |
| | ) |
| | try: |
| | subprocess.run( |
| | cmd.split(), |
| | capture_output=True, |
| | text=True, |
| | check=True, |
| | ) |
| | except subprocess.CalledProcessError as e: |
| | print( |
| | f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", |
| | file=sys.stderr, |
| | ) |
| | raise |
| |
|
| | if config.aot_inductor.embed_kernel_binary: |
| | # Convert the cubin to an object file so it links directly into the .so.
| | cubins_o.append( |
| | convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) |
| | ) |
| |
|
| | output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) |
| | so_build_options = CppTorchDeviceOptions( |
| | vec_isa=picked_vec_isa, |
| | device_type=device_type, |
| | aot_mode=graph.aot_mode, |
| | use_relative_path=use_relative_path, |
| | ) |
| |
|
| | obj_srcs = [wrapper_o, kernel_o, consts_o, *gpu_kernels_o, *cubins_o] |
| | so_builder = CppBuilder( |
| | name=output_name, |
| | sources=obj_srcs, |
| | output_dir=output_dir, |
| | BuildOption=so_build_options, |
| | ) |
| | link_cmd = so_builder.get_command_line() |
| | output_so = so_builder.get_target_file_path() |
| |
|
| | log.debug("aot linkage command: %s", link_cmd) |
| |
|
| | # Append the exact compile/link commands to the sources for debugging.
| | with open(wrapper_path, "a") as f: |
| | f.write("\n") |
| | f.write(f"// Compile cmd\n// {wrapper_compile_cmd}\n") |
| | f.write(f"// Link cmd\n// {link_cmd}\n") |
| |
|
| | with open(kernel_path, "a") as f: |
| | f.write("\n") |
| | f.write(f"// Compile cmd\n// {kernel_compile_cmd}\n") |
| | f.write(f"// Link cmd\n// {link_cmd}\n") |
| |
|
| | if config.aot_inductor.package_cpp_only: |
| | linker_flags = str( |
| | wrapper_path_operator.with_name( |
| | f"{wrapper_path_operator.stem}_linker_flags.json" |
| | ) |
| | ) |
| | so_build_options.save_flags_to_json(linker_flags) |
| | generated_files.append(linker_flags) |
| | generated_files.append(_LINKER_SCRIPT) |
| |
|
| | # With mmapped weights the payload ships as a separate .bin file;
| | # otherwise the consts object is linked in via CMake.
| | if use_mmap_weights: |
| | weight_file = str( |
| | wrapper_path_operator.with_name( |
| | f"{wrapper_path_operator.stem}_serialized_weights.bin" |
| | ) |
| | ) |
| | with open(weight_file, "wb") as f_weights: |
| | f_weights.write(serialized_weights) |
| | f_weights.write(struct.pack("q", magic_number)) |
| |
|
| | generated_files.append(weight_file) |
| | else: |
| | |
| | generated_files.append(consts_o) |
| | so_builder.save_src_to_cmake(cmake_path, consts_o) |
| |
|
| | if config.aot_inductor.emit_multi_arch_kernel: |
| | so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files) |
| | generated_files.extend(asm_files) |
| | else: |
| | obj_srcs = [*gpu_kernels_o, *cubins_o] |
| | generated_files.extend(obj_srcs) |
| | for obj in obj_srcs: |
| | so_builder.save_src_to_cmake(cmake_path, obj) |
| |
|
| | so_builder.save_link_cmd_to_cmake(cmake_path) |
| | else: |
| | so_builder.build() |
| | for o_file in obj_srcs: |
| | if o_file in gpu_kernels_o: |
| | continue |
| | # Intermediate objects are no longer needed once the .so is linked.
| | os.remove(o_file) |
| |
|
| | if use_mmap_weights: |
| |
|
| | def get_page_size() -> int: |
| | # resource.getpagesize() is POSIX-only; on Windows query the page
| | # size through the Win32 GetSystemInfo API instead.
| | if _IS_WINDOWS: |
| | from ctypes import ( |
| | byref, |
| | Structure, |
| | windll, |
| | ) |
| | from ctypes.wintypes import DWORD, LPVOID, WORD |
| |
|
| | class SYSTEM_INFO(Structure): |
| | _fields_ = [ |
| | ("wProcessorArchitecture", WORD), |
| | ("wReserved", WORD), |
| | ("dwPageSize", DWORD), |
| | ("lpMinimumApplicationAddress", LPVOID), |
| | ("lpMaximumApplicationAddress", LPVOID), |
| | ("dwActiveProcessorMask", DWORD), |
| | ("dwNumberOfProcessors", DWORD), |
| | ("dwProcessorType", DWORD), |
| | ("dwAllocationGranularity", DWORD), |
| | ("wProcessorLevel", WORD), |
| | ("wProcessorRevision", WORD), |
| | ] |
| |
|
| | si = SYSTEM_INFO() |
| | windll.kernel32.GetSystemInfo(byref(si)) |
| | sys_page_size = si.dwPageSize |
| | else: |
| | import resource |
| |
|
| | sys_page_size = resource.getpagesize() |
| |
|
| | return sys_page_size |
| |
|
| | page_size_ = get_page_size() |
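| | # Floor the alignment at 16KB so the artifact stays mmap-able on
| | # 16K-page systems (e.g. some arm64 hosts), even if built where the
| | # page size is 4KB.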
| | page_size = max(16384, page_size_) |
| |
|
| | with open(output_so, "a+b") as f_so: |
| | so_size = f_so.tell() |
| | # Pad the .so to a page boundary, then append the weights and magic.
| | f_so.write(b" " * (page_size - so_size % page_size)) |
| | f_so.write(serialized_weights) |
| | f_so.write(struct.pack("q", magic_number)) |
| |
|
| | if config.aot_inductor.package: |
| | generated_files.append(output_so) |
| |
|
| | if config.aot_inductor.package: |
| | if config.trace.provenance_tracking_level != 0: |
| | kernel_info = torch._inductor.debug.create_kernel_information_json() |
| | kernel_info_json = os.path.join( |
| | wrapper_path_operator.parent, "kernel_information.json" |
| | ) |
| | with open(kernel_info_json, "w") as f: |
| | f.write(json.dumps(kernel_info, indent=4)) |
| | generated_files.append(kernel_info_json) |
| |
|
| | # In packaging mode, return the generated files rather than a .so path.
| | return generated_files |
| |
|
| | return output_so |
| |
|
| |
|
| | _libgomp: Optional[CDLL] = None |
| |
|
| |
|
| | def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]: |
| | # Called from generated cpp wrapper code (JIT mode). Tensor arguments
| | # arrive as PyCapsules wrapping AtenTensorHandle and must be converted.
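| | # Illustrative call from generated wrapper code (names hypothetical):
| | #   custom_op_wrapper("torch.ops.my_ns.my_op.default", tensor_capsule, 3)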
| | def convert_arg(arg: Any) -> Any: |
| | if str(type(arg)) == "<class 'PyCapsule'>": |
| | # A PyCapsule wraps an AtenTensorHandle; steal it into a real Tensor.
| | return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg) |
| | elif isinstance(arg, (list, tuple)): |
| | return type(arg)(convert_arg(a) for a in arg) |
| | else: |
| | return arg |
| |
|
| | converted_args = [convert_arg(arg) for arg in args] |
| |
|
| | assert op.startswith("torch.ops."), ( |
| | op + " can not be called through custom_op_wrapper" |
| | ) |
| | func = None |
| | for i, s in enumerate(op.split(".")): |
| | if i == 0: |
| | func = importlib.import_module(s) |
| | func = getattr(func, s) |
| |
|
| | assert callable(func), op + " cannot be loaded through custom_op_wrapper"
| |
|
| | # Split trailing kwarg-only arguments out of the positional list.
| | kwargs = dict() |
| | for func_arg, conv_arg in zip(func._schema.arguments, converted_args): |
| | if func_arg.kwarg_only: |
| | kwargs[func_arg.name] = conv_arg |
| | if kwargs: |
| | del converted_args[-len(kwargs) :] |
| |
|
| | result = func(*converted_args, **kwargs) |
| | if result is None: |
| | return None |
| |
|
| | if isinstance(result, (list, tuple)): |
| | # Replace None entries with empty tensors; the C ABI expects real handles.
| | result = [torch.tensor([]) if r is None else r for r in result] |
| | for i, r in enumerate(result): |
| | assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" |
| | return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) |
| |
|
| | assert isinstance(result, torch.Tensor), op + " returns a non-tensor" |
| | return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) |
| |
|
| |
|
| | # Precompiled headers are reused across processes and runs, so they
| | # live in the persistent default cache dir rather than a per-run one.
| | _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers") |
| | _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks") |
| |
|
| |
|
| | @functools.cache |
| | def _precompile_header( |
| | header: str, |
| | hashable_cmd_line: str, |
| | **compile_command: Any, |
| | ) -> str: |
| | assert not _IS_WINDOWS, ( |
| | "CppBuilder does not currently support precompiling on Windows!" |
| | ) |
| |
|
| | # Hash the *preprocessed* header rather than the raw file so that a
| | # change in any transitively #include'd header invalidates the cache.
| | # Preprocess into a temp dir, checksum the output, and fold the result
| | # into the cache key along with the command line and compiler version.
| | with tempfile.TemporaryDirectory() as preprocessing_dir: |
| | preprocessing_header = Path(preprocessing_dir) / "header.hpp" |
| | preprocessing_header.write_text(f"#include <{header}>\n") |
| | preprocessor = CppBuilder( |
| | name=str(preprocessing_header)[:-4], |
| | sources=str(preprocessing_header), |
| | BuildOption=CppTorchDeviceOptions(**compile_command, preprocessing=True), |
| | ) |
| | preprocessor.build() |
| |
|
| | def _get_file_checksum(filename: str) -> str: |
| | """Reading the whole preprocessed header in for hashing is very expensive, |
| | but calling a fast hashing utility in a subprocess is cheap.""" |
| | |
| | cmd_output = subprocess.run( |
| | ("openssl", "sha512", filename), capture_output=True, text=True |
| | ) |
| | return cmd_output.stdout.split()[-1] |
| |
|
| | preprocessor_hash = _get_file_checksum(preprocessor.get_target_file_path()) |
| |
|
| | header_build_option = CppTorchDeviceOptions(**compile_command, precompiling=True) |
| | header_hash, header_full_path = write( |
| | content=f"#include <{header}>\n", |
| | extension="h", |
| | extra=( |
| | hashable_cmd_line |
| | + preprocessor_hash |
| | + get_compiler_version_info(header_build_option.get_compiler()) |
| | ), |
| | specified_dir=_HEADER_DIR, |
| | ) |
| | cpp_builder = CppBuilder( |
| | name=header_full_path, |
| | sources=header_full_path, |
| | BuildOption=header_build_option, |
| | ) |
| | # Build under a per-header file lock so concurrent processes don't
| | # race to produce the same precompiled header.
| | os.makedirs(_HEADER_LOCK_DIR, exist_ok=True) |
| | _worker_compile_cpp( |
| | os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"), |
| | (cpp_builder,), |
| | ) |
| |
|
| | return header_full_path |
| |
|
| |
|
| | def _get_cpp_prefix_header(device: str) -> Optional[str]: |
| | if device.startswith("cpu"): |
| | return "torch/csrc/inductor/cpp_prefix.h" |
| | return None |
| |
|
| |
|
| | def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: |
| | """Given a device type (and optionally whether we're in AOT Inductor mode), returns |
| | the path to the cpp_wrapper header file to be precompiled.""" |
| | base_device = device.split(":", maxsplit=1)[0] |
| | is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" |
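| | # e.g. device "cuda" in AOT mode resolves to
| | # "torch/csrc/inductor/aoti_include/cuda.h"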
| | return ( |
| | "torch/csrc/inductor/" |
| | f"{'aoti_include' if aot_mode else 'cpp_wrapper'}/" |
| | f"{'array_ref' if is_array_ref else base_device}.h" |
| | ) |
| |
|
| |
|
| | @clear_on_fresh_cache |
| | class CppCodeCache: |
| | """Compiles and caches C++ libraries. Users of this class supply the source code to |
| | be compiled, while compilation flags are set by CppBuilder.""" |
| |
|
| | cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} |
| | cache_clear = staticmethod(cache.clear) |
| | cpp_compile_command_flags: dict[str, Any] = {} |
| |
|
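| | # Subclasses (e.g. CppPythonBindingsCodeCache) override _load_library_inner
| | # to import the built artifact as a Python extension module rather than
| | # loading it as a raw CDLL.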
| | @staticmethod |
| | def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: |
| | return cdll.LoadLibrary(path) |
| |
|
| | @classmethod |
| | def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: |
| | try: |
| | result = cls._load_library_inner(path, key) |
| | result.key = key |
| | return result |
| | except (ImportError, OSError) as e: |
| | if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): |
| | # Workaround: preload the system libgomp, then retry the load.
| | global _libgomp |
| | _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") |
| | result = cls._load_library_inner(path, key) |
| | result.key = key |
| | return result |
| | if "failed to map segment from shared object" in str(e): |
| | raise OSError( |
| | f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " |
| | "is mounted with noexec (e.g., by default Docker mounts tmp file systems " |
| | f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " |
| | "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." |
| | ) from e |
| | raise |
| |
|
| | @classmethod |
| | def _get_uncompiled_header(cls, device: str) -> str | None: |
| | """ |
| | Given a device type, returns the path to a CPP header file to be precompiled. |
| | """ |
| | return None |
| |
|
| | @classmethod |
| | def load_async( |
| | cls, |
| | main_code: str, |
| | device_type: str = "cpu", |
| | submit_fn: Any = None, |
| | extra_flags: Sequence[str] = (), |
| | optimized_code: Optional[str] = None, |
| | ) -> Any: |
| | """Compile and load a C++ library. Returns a callable that returns the loaded |
| | library.""" |
| | compile_command = { |
| | **cls.cpp_compile_command_flags, |
| | "device_type": device_type, |
| | "extra_flags": extra_flags, |
| | "use_relative_path": config.is_fbcode(), |
| | "vec_isa": pick_vec_isa(), |
| | } |
| |
|
| | _set_gpu_runtime_env() |
| |
|
| | # When optimized_code is present, main_code is compiled to an object
| | # at reduced optimization and linked against the fully optimized
| | # kernel object below; otherwise it is compiled and linked in one step.
| | main_build_option = CppTorchDeviceOptions( |
| | compile_only=bool(optimized_code), |
| | min_optimize=optimized_code is not None, |
| | **compile_command, |
| | ) |
| | optimized_build_option = CppTorchDeviceOptions( |
| | compile_only=True, **compile_command |
| | ) |
| |
|
| | def get_hashable_command_line(build_option: BuildOptionsBase) -> str: |
| | """Writing the code to file will calculate a hash, which we need to vary if |
| | the command line flags change. This implements a mostly-generic way of |
| | validating that.""" |
| | return CppBuilder( |
| | name="o", sources="i", BuildOption=build_option |
| | ).get_command_line() |
| |
|
| | main_cmd_line = get_hashable_command_line(main_build_option) |
| | optimized_cmd_line = get_hashable_command_line(optimized_build_option) |
| |
|
| | key, main_path = write( |
| | main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}" |
| | ) |
| |
|
| | |
| | if optimized_code: |
| | _, optimized_path = write( |
| | optimized_code, "optimized.cpp", extra=optimized_cmd_line |
| | ) |
| | else: |
| | # Placeholder path; nothing is ever compiled from it.
| | optimized_path = os.devnull |
| |
|
| | if key not in cls.cache: |
| | from torch.utils._filelock import FileLock |
| |
|
| | lock_path = os.path.join(get_lock_dir(), key + ".lock") |
| | future: Optional[Future[Any]] = None |
| | lib = None |
| |
|
| | # Precompiled headers are not currently supported on Windows.
| | if config.cpp_cache_precompile_headers and not _IS_WINDOWS: |
| | if header := cls._get_uncompiled_header(device_type): |
| | main_build_option.precompiled_header = _precompile_header( |
| | header, |
| | main_cmd_line, |
| | min_optimize=optimized_code is not None, |
| | **compile_command, |
| | ) |
| |
|
| | # The kernel translation unit uses the shared cpp_prefix header,
| | # which is identical across kernels and therefore worth precompiling.
| | if optimized_code and (header := _get_cpp_prefix_header(device_type)): |
| | optimized_build_option.precompiled_header = _precompile_header( |
| | header, |
| | optimized_cmd_line, |
| | **compile_command, |
| | ) |
| |
|
| | main_name, output_dir = get_name_and_dir_from_output_file_path(main_path) |
| | main_builder = CppBuilder( |
| | name=main_name, |
| | sources=main_path, |
| | BuildOption=main_build_option, |
| | output_dir=output_dir, |
| | ) |
| |
|
| | if optimized_code: |
| | optimized_name, _ = get_name_and_dir_from_output_file_path( |
| | optimized_path |
| | ) |
| | optimized_builder = CppBuilder( |
| | name=optimized_name, |
| | sources=optimized_path, |
| | BuildOption=optimized_build_option, |
| | output_dir=output_dir, |
| | ) |
| |
|
| | linker = CppBuilder( |
| | name=main_name, |
| | sources=[ |
| | main_builder.get_target_file_path(), |
| | optimized_builder.get_target_file_path(), |
| | ], |
| | BuildOption=CppTorchDeviceOptions(**compile_command), |
| | output_dir=output_dir, |
| | ) |
| |
|
| | worker_fn = functools.partial( |
| | _worker_compile_cpp, |
| | lock_path, |
| | (main_builder, optimized_builder, linker), |
| | ) |
| | binary_path = normalize_path_separator(linker.get_target_file_path()) |
| | else: |
| | worker_fn = functools.partial( |
| | _worker_compile_cpp, lock_path, (main_builder,) |
| | ) |
| | binary_path = normalize_path_separator( |
| | main_builder.get_target_file_path() |
| | ) |
| |
|
| | def load_fn() -> Any: |
| | nonlocal lib |
| | if lib is None: |
| | if future is not None: |
| | future.result() |
| | result = worker_fn() |
| | assert result is None |
| | lib = cls._load_library(binary_path, key) |
| | assert lib is not None |
| | return lib |
| |
|
| | if submit_fn is not None: |
| | with FileLock(lock_path, timeout=LOCK_TIMEOUT): |
| | if not os.path.exists(binary_path): |
| | future = submit_fn(worker_fn) |
| |
|
| | cls.cache[key] = load_fn |
| |
|
| | return cls.cache[key] |
| |
|
| | @classmethod |
| | def load(cls, *args: Any, **kwargs: Any) -> Any: |
| | return cls.load_async(*args, **kwargs)() |
| |
|
| |
|
| | def _worker_compile_cpp( |
| | lock_path: str, |
| | cpp_builders: Sequence[CppBuilder], |
| | ) -> None: |
| | from torch.utils._filelock import FileLock |
| |
|
| | with FileLock(lock_path, timeout=LOCK_TIMEOUT): |
| | for builder in cpp_builders: |
| | if not os.path.exists(builder.get_target_file_path()): |
| | builder.build() |
| |
|
| |
|
| | |
| | @clear_on_fresh_cache |
| | class CppPythonBindingsCodeCache(CppCodeCache): |
| | cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} |
| | cache_clear = staticmethod(cache.clear) |
| | cpp_compile_command_flags = { |
| | # Standalone kernel bindings don't need to link against libtorch.
| | "include_pytorch": False, |
| | "shared": True, |
| | } |
| | entry_function = "kernel" |
| | call_entry_function = "kernel({}); Py_RETURN_NONE;" |
| | extra_parse_arg = "" |
| | suffix_template = textwrap.dedent( |
| | """ |
| | // Python bindings to call {entry_func}(): |
| | #define PY_SSIZE_T_CLEAN |
| | #include <Python.h> |
| | #include <sstream> |
| | #include <cstdlib> |
| | |
| | #ifndef _MSC_VER |
| | #if __cplusplus < 202002L |
| | // C++20 (earlier) code |
| | // https://en.cppreference.com/w/cpp/language/attributes/likely |
| | #define likely(x) __builtin_expect(!!(x), 1) |
| | #define unlikely(x) __builtin_expect(!!(x), 0) |
| | #endif |
| | #else |
| | #define likely(x) (x) |
| | #define unlikely(x) (x) |
| | #endif |
| | |
| | // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. |
| | // We manually link it below to workaround issues with fbcode build. |
| | static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); |
| | |
| | template <typename T> static inline T parse_arg(PyObject* args, size_t n) {{ |
| | static_assert(std::is_pointer_v<T>, "arg type must be pointer or long"); |
| | return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); |
| | }} |
| | template <> inline int64_t parse_arg<int64_t>(PyObject* args, size_t n) {{ |
| | auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); |
| | if(unlikely(result == -1 && PyErr_Occurred())) |
| | throw std::runtime_error("expected int arg"); |
| | return result; |
| | }} |
| | template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) {{ |
| | auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); |
| | if(unlikely(result == reinterpret_cast<void*>(-1) && PyErr_Occurred())) |
| | throw std::runtime_error("expected int arg"); |
| | return reinterpret_cast<uintptr_t>(result); |
| | }} |
| | |
| | {extra_parse_arg} |
| | |
| | static PyObject* {entry_func}_py(PyObject* self, PyObject* args) {{ |
| | try {{ |
| | if(unlikely(!PyTuple_CheckExact(args))) |
| | throw std::runtime_error("tuple args required"); |
| | if(unlikely(PyTuple_GET_SIZE(args) != {arg_len})) |
| | throw std::runtime_error("requires {arg_len} args"); |
| | {call_entry_func} |
| | }} catch(std::exception const& e) {{ |
| | PyErr_SetString(PyExc_RuntimeError, e.what()); |
| | return nullptr; |
| | }} catch(...) {{ |
| | PyErr_SetString(PyExc_RuntimeError, "unhandled error"); |
| | return nullptr; |
| | }} |
| | }} |
| | |
| | static PyMethodDef py_methods[] = {{ |
| | {{"{entry_func}", {entry_func}_py, METH_VARARGS, ""}}, |
| | {{NULL, NULL, 0, NULL}}}}; |
| | |
| | static struct PyModuleDef py_module = |
| | {{PyModuleDef_HEAD_INIT, "{entry_func}", NULL, -1, py_methods}}; |
| | |
| | PyMODINIT_FUNC PyInit_{entry_func}(void) {{ |
| | const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); |
| | if(!str_addr) {{ |
| | PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); |
| | return nullptr; |
| | }} |
| | std::istringstream iss(str_addr); |
| | uintptr_t addr = 0; |
| | iss >> addr; |
| | _torchinductor_pyobject_tensor_data_ptr = |
| | reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr); |
| | PyObject* module = PyModule_Create(&py_module); |
| | if (module == NULL) {{ |
| | return NULL; |
| | }} |
| | #ifdef Py_GIL_DISABLED |
| | PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); |
| | #endif |
| | return module; |
| | }} |
| | """ |
| | ) |
| |
|
| | @classmethod |
| | def _load_library_inner(cls, path: str, key: str) -> ModuleType: |
| | os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( |
| | torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr |
| | ) |
| | module_name = f"{key}.{cls.entry_function}" |
| | try: |
| | return sys.modules[module_name] |
| | except KeyError: |
| | pass |
| | spec = importlib.util.spec_from_file_location(module_name, path) |
| | assert spec is not None |
| | module = importlib.util.module_from_spec(spec) |
| | sys.modules[module_name] = module |
| | assert spec.loader is not None |
| | spec.loader.exec_module(module) |
| | return module |
| |
|
| | @classmethod |
| | def _get_uncompiled_header(cls, device: str) -> str | None: |
| | return _get_cpp_prefix_header(device) |
| |
|
| | @classmethod |
| | def load_pybinding_async( |
| | cls, |
| | argtypes: Sequence[str], |
| | main_code: str, |
| | device_type: str = "cpu", |
| | num_outputs: int = -1, |
| | submit_fn: Any = None, |
| | extra_flags: Sequence[str] = (), |
| | kernel_code: Optional[str] = None, |
| | ) -> Any: |
| | """ |
| | Wrap a C++ function in fast Python bindings. |
| | |
| | Args: |
| | argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"] |
| | main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at |
| | -O3 if kernel_code is None (to maximize performance in any kernels that |
| | are present), or -O1 otherwise (to minimize compile time). |
| | kernel_code: If present, C++ source code that will be built at -O3 and |
| | linked to main_code. |
| | |
| | Returns: |
| | A python version of ENTRY_FUNCTION() |
| | """ |
| | parseargs = ", ".join( |
| | f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" |
| | for n, argtype in enumerate(argtypes) |
| | ) |
| | suffix = cls.suffix_template.format( |
| | arg_len=len(argtypes), |
| | call_entry_func=cls.call_entry_function.format(parseargs), |
| | entry_func=cls.entry_function, |
| | extra_parse_arg=cls.extra_parse_arg.format(array_len=num_outputs), |
| | ) |
| | get_result = cls.load_async( |
| | main_code + suffix, |
| | device_type, |
| | submit_fn=submit_fn, |
| | extra_flags=extra_flags, |
| | optimized_code=kernel_code, |
| | ) |
| | result = None |
| |
|
| | def future() -> Any: |
| | nonlocal result |
| | if result is None: |
| | result = get_result() |
| | assert isinstance(result, ModuleType) |
| | return getattr(result, cls.entry_function) |
| |
|
| | return future |
| |
|
| | @classmethod |
| | def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: |
| | return cls.load_pybinding_async(*args, **kwargs)() |
| |
|
| |
|
| | @clear_on_fresh_cache |
| | class CppWrapperCodeCache(CppPythonBindingsCodeCache): |
| | cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} |
| | cache_clear = staticmethod(cache.clear) |
| | cpp_compile_command_flags = { |
| | "include_pytorch": True, |
| | "shared": True, |
| | } |
| | entry_function = "inductor_entry_cpp" |
| | call_entry_function = "return inductor_entry_cpp({});" |
| | extra_parse_arg = textwrap.dedent( |
| | """ |
| | #include <torch/csrc/inductor/aoti_torch/c/shim.h> |
| | |
| | static inline std::vector<AtenTensorHandle> unpack_tensor_handle_list(PyObject* pyvec) {{ |
| | std::vector<AtenTensorHandle> result; |
| | size_t result_len = PyList_GET_SIZE(pyvec); |
| | result.reserve(result_len); |
| | for (size_t i = 0; i < result_len; i++) {{ |
| | // AtenTensorHandle is essentially a pointer |
| | void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); |
| | result.push_back(reinterpret_cast<AtenTensorHandle>(elem)); |
| | }} |
| | return result; |
| | }} |
| | |
| | static inline PyObject* pack_tensor_handle_list(const std::array<AtenTensorHandle, {array_len}>& arr) {{ |
| | PyObject* result = PyList_New({array_len}); |
| | for (size_t i = 0; i < {array_len}; i++) {{ |
| | PyObject *elem = |
| | arr[i] == nullptr |
| | ? Py_None |
| | // Store AtenTensorHandle as PyCapsulate |
| | : PyCapsule_New(reinterpret_cast<void*>(arr[i]), NULL, NULL); |
| | PyList_SET_ITEM(result, i, elem); |
| | }} |
| | return result; |
| | }} |
| | |
| | template <> inline std::vector<AtenTensorHandle> parse_arg<std::vector<AtenTensorHandle>>(PyObject* args, size_t n) {{ |
| | return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); |
| | }} |
| | |
| | PyObject* inductor_entry_cpp(std::vector<AtenTensorHandle>&& input_handles) {{ |
| | // For outputs, we only allocate an array to hold returned tensor handles, |
| | // not the actual output tensor storage. |
| | std::array<AtenTensorHandle, {array_len}> output_handles{{}}; |
| | try {{ |
| | inductor_entry_impl(input_handles.data(), output_handles.data()); |
| | if (PyErr_Occurred()) {{ |
| | return nullptr; |
| | }} |
| | return pack_tensor_handle_list(output_handles); |
| | }} catch(std::exception const& e) {{ |
| | PyErr_SetString(PyExc_RuntimeError, e.what()); |
| | return nullptr; |
| | }} catch(...) {{ |
| | PyErr_SetString(PyExc_RuntimeError, "unhandled error"); |
| | return nullptr; |
| | }} |
| | }} |
| | """ |
| | ) |
| |
|
| | @classmethod |
| | def _get_uncompiled_header(cls, device: str) -> str | None: |
| | return _get_cpp_wrapper_header(device) |
| |
|
| |
|
| | @clear_on_fresh_cache |
| | class HalideCodeCache(CppPythonBindingsCodeCache): |
| | cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} |
| | cache_clear = staticmethod(cache.clear) |
| | _standalone_runtime_path: Optional[str] = None |
| | prefix = textwrap.dedent( |
| | """ |
| | #include "{halideruntime_h}" |
| | #include "{headerfile}" |
| | #include <stdexcept> |
| | #include <cmath> |
| | |
| | namespace c10 {{ |
| | inline long div_floor_integer(long a, long b) {{ |
| | if ((a<0) != (b<0)) {{ |
| | const auto quot = a / b; |
| | const auto rem = a % b; |
| | return rem ? quot - 1 : quot; |
| | }} |
| | return a / b; |
| | }} |
| | }} |
| | """ |
| | ) |
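| | # div_floor_integer matches Python floor division for negative operands:
| | # e.g. div_floor_integer(-3, 2) == -2, whereas C++ integer division
| | # truncates -3 / 2 to -1.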
| | glue_template_cpp = prefix + textwrap.dedent( |
| | """ |
| | void kernel({argdefs}) {{ |
| | {buffers} |
| | int err = halide_kernel({buffer_names}); |
| | if(err != 0) throw std::runtime_error("halide_kernel failed"); |
| | }} |
| | """ |
| | ) |
| | glue_template_cuda = prefix + textwrap.dedent( |
| | """ |
| | #include <cuda.h> |
| | static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface(); |
| | |
| | void kernel({argdefs}, uintptr_t stream) {{ |
| | {buffers} |
| | int err = halide_kernel(reinterpret_cast<void*>(stream), {buffer_names}); |
| | if(err != 0) throw std::runtime_error("halide_kernel failed"); |
| | }} |
| | """ |
| | ) |
| | standalone_runtime_cuda_init = textwrap.dedent( |
| | """ |
| | #include "{}" |
| | #include <cuda.h> |
| | |
| | static int acquire_context(void* user_context, |
| | void** cuda_context_out, |
| | bool create) {{ |
| | return cuCtxGetCurrent(reinterpret_cast<CUcontext*>(cuda_context_out)); |
| | }} |
| | |
| | static int release_context(void* user_context) {{ |
| | return 0; |
| | }} |
| | |
| | static int get_stream(void* user_context, |
| | void* cuda_context, |
| | void** stream_out) {{ |
| | *stream_out = user_context; |
| | return 0; |
| | }} |
| | |
| | static int register_halide_hooks() {{ |
| | halide_set_cuda_acquire_context(&acquire_context); |
| | halide_set_cuda_release_context(&release_context); |
| | halide_set_cuda_get_stream(&get_stream); |
| | return 0; |
| | }} |
| | |
| | int inductor_register_halide_hooks_result = register_halide_hooks(); |
| | """ |
| | ) |
| |
|
| | @classmethod |
| | def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]: |
| | assert arg.shape is not None |
| | assert arg.stride is not None and len(arg.shape) == len(arg.stride) |
| | assert arg.offset is not None |
| | data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}" |
| | if cuda: |
| | device = f"reinterpret_cast<uint64_t>({data_ptr})" |
| | device_interface = "cuda_interface" |
| | host = "nullptr" |
| | flags = "halide_buffer_flag_device_dirty" |
| | else: |
| | device = "0" |
| | device_interface = "nullptr" |
| | host = f"reinterpret_cast<uint8_t*>({data_ptr})" |
| | flags = "halide_buffer_flag_host_dirty" |
| |
|
| | dims = [] |
| | for size, stride in zip(arg.shape, arg.stride): |
| | dims.append(f"halide_dimension_t(0, {size}, {stride})") |
| |
|
| | return [ |
| | f"halide_buffer_t {name};", |
| | f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};" |
| | if len(dims) > 0 |
| | else f"halide_dimension_t * {name}_dims = nullptr;", |
| | f"{name}.device = {device};", |
| | f"{name}.device_interface = {device_interface};", |
| | f"{name}.host = {host};", |
| | f"{name}.flags = {flags};", |
| | f"{name}.type = {arg.halide_type()};", |
| | f"{name}.dimensions = {len(dims)};", |
| | f"{name}.dim = {name}_dims;", |
| | f"{name}.padding = nullptr;", |
| | ] |
| |
|
| | @classmethod |
| | def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: |
| | is_cuda = meta.is_cuda() |
| | assert is_cuda is ("user_context" in meta.target) |
| | assert "no_runtime" in meta.target |
| | buffers = [] |
| | buffer_names = [] |
| | for i, arg in enumerate(meta.argtypes): |
| | if arg.is_buffer(): |
| | buffer_names.append(f"&hl_buf_{i}") |
| | buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) |
| | else: |
| | assert "*" not in arg.ctype |
| | buffer_names.append(arg.name) |
| | buffers = "\n".join([f" {line}" for line in buffers]).lstrip() |
| |
|
| | glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp |
| | glue_code = glue_template.format( |
| | halideruntime_h=cls.find_header( |
| | "HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h" |
| | ), |
| | headerfile=headerfile, |
| | argdefs=", ".join( |
| | f"{a.bindings_type()} {a.name}" |
| | for a in meta.argtypes |
| | if a.alias_of is None |
| | ), |
| | buffers=buffers, |
| | buffer_names=", ".join(buffer_names), |
| | ) |
| | return glue_code |
| |
|
| | @classmethod |
| | @functools.cache |
| | def config_hash(cls) -> str: |
| | command_gen = CppBuilder( |
| | name="O", |
| | sources="I", |
| | BuildOption=CppOptions(), |
| | ) |
| | command_line = command_gen.get_command_line() |
| | return sha256_hash( |
| | "\n".join( |
| | [ |
| | cls.glue_template_cpp, |
| | cls.glue_template_cuda, |
| | cls.standalone_runtime_cuda_init, |
| | command_line, |
| | ] |
| | ).encode("utf-8") |
| | ) |
| |
|
| | @staticmethod |
| | def _search_for_file(suffix: str, errmsg: str) -> str: |
| | spec = importlib.machinery.PathFinder.find_spec("halide") |
| | if spec is None or not spec.submodule_search_locations: |
| | raise RuntimeError("halide python bindings not installed") |
| | try: |
| | search = spec.submodule_search_locations[0] |
| | for file in os.listdir(search): |
| | if file.endswith(".so"): |
| | try: |
| | out = subprocess.check_output( |
| | ["ldd", os.path.join(search, file)] |
| | ) |
| | except subprocess.SubprocessError: |
| | continue |
| | m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) |
| | if m: |
| | path = os.path.join(os.path.abspath(m.group(1)), suffix) |
| | if os.path.exists(path): |
| | return os.path.abspath(path) |
| | except Exception as e: |
| | raise RuntimeError(errmsg) from e |
| | raise RuntimeError(errmsg) |
| |
|
| | @staticmethod |
| | @functools.cache |
| | def find_libautoschedule(name: str) -> str: |
| | sofile = f"libautoschedule_{name.lower()}.so" |
| | if "HALIDE_LIB" in os.environ: |
| | path = os.path.join(os.environ["HALIDE_LIB"], sofile) |
| | if os.path.exists(path): |
| | return path |
| | errmsg = ( |
| | f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" |
| | ) |
| | return HalideCodeCache._search_for_file(sofile, errmsg) |
| |
|
| | @staticmethod |
| | @functools.cache |
| | def find_header(name: str) -> str: |
| | if "HALIDE_INCLUDE" in os.environ: |
| | path = os.path.join(os.environ["HALIDE_INCLUDE"], name) |
| | if os.path.exists(path): |
| | return path |
| | if "HALIDE_LIB" in os.environ: |
| | path = os.path.abspath( |
| | os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") |
| | ) |
| | if os.path.exists(path): |
| | return path |
| | errmsg = ( |
| | f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" |
| | ) |
| | return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) |
| |
|
| | @classmethod |
| | def generate_halide_async( |
| | cls, meta: HalideMeta, source_code: str, submit_fn: Any = None |
| | ) -> Callable[[], Any]: |
| | dirpath = Path( |
| | get_path( |
| | code_hash( |
| | source_code, |
| | extra=repr((cls.config_hash(), meta)), |
| | ), |
| | "halide", |
| | )[2] |
| | ) |
| | os.makedirs(dirpath, exist_ok=True) |
| | wait_for_compile = None |
| | genfile = str(dirpath / "generate_kernel.py") |
| | libfile = str(dirpath / "halide_kernel.a") |
| | headerfile = str(dirpath / "halide_kernel.h") |
| | donefile = str(dirpath / "done") |
| | lockfile = str(dirpath / "lock") |
| | need_compile = not os.path.exists(donefile) |
| | jobs: list[Any] = [] |
| | if need_compile: |
| | write_atomic(genfile, source_code) |
| | cmd = [ |
| | sys.executable, |
| | genfile, |
| | "-g", |
| | "kernel", |
| | "-o", |
| | f"{dirpath}", |
| | "-f", |
| | "halide_kernel", |
| | "-e", |
| | "static_library,h,schedule", |
| | ] |
| | if meta.scheduler: |
| | cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)]) |
| | cmd.extend(meta.args()) |
| | jobs.append(functools.partial(subprocess.check_call, cmd)) |
| |
|
| | binding_types = [ |
| | arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None |
| | ] |
| | if meta.is_cuda(): |
| | binding_types.append("uintptr_t") |
| | bindings_future = cls.load_pybinding_async( |
| | binding_types, |
| | cls._codegen_glue(meta, headerfile), |
| | extra_flags=(libfile, cls.build_standalone_runtime()), |
| | submit_fn=jobs.append if need_compile else None, |
| | device_type="cuda" if meta.is_cuda() else "cpu", |
| | ) |
| |
|
| | if need_compile: |
| | jobs.append(functools.partial(touch, donefile)) |
| | task = functools.partial(_worker_task_halide, lockfile, jobs) |
| | if submit_fn: |
| | wait_for_compile = submit_fn(task).result |
| | else: |
| | task() |
| |
|
| | def load() -> Callable[[], Any]: |
| | if wait_for_compile: |
| | wait_for_compile() |
| | return bindings_future() |
| |
|
| | return load |
| |
|
| | @classmethod |
| | def generate_halide(cls, *args: Any, **kwargs: Any) -> Callable[[], Any]: |
| | return cls.generate_halide_async(*args, **kwargs)() |
| |
|
| | @classmethod |
| | def build_standalone_runtime(cls) -> str: |
| | if cls._standalone_runtime_path and os.path.exists( |
| | cls._standalone_runtime_path |
| | ): |
| | return cls._standalone_runtime_path |
| | device_type = "cuda" if torch.cuda.is_available() else "cpu" |
| | libname = "libStandaloneHalideRuntime.so" |
| | target = "host-cuda" if device_type == "cuda" else "host" |
| | if cls._standalone_runtime_path: |
| | assert not os.path.exists(cls._standalone_runtime_path) |
| | # Hit in unit tests running under fresh_inductor_cache(): the runtime
| | # is held in global state, so rebuilding it per fresh cache causes
| | # errors. Cache it in the persistent default dir instead.
| | base = default_cache_dir() |
| | else: |
| | base = cache_dir() |
| | dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" |
| | os.makedirs(dirpath, exist_ok=True) |
| | done_file = str(dirpath / "done") |
| | lock_file = str(dirpath / "lock") |
| | hook_file = str(dirpath / "hooks.cpp") |
| | a_file = str(dirpath / "standalone_halide_runtime.a") |
| | so_file = str(dirpath / libname) |
| | if not os.path.exists(done_file): |
| | import halide as hl |
| |
|
| | from torch.utils._filelock import FileLock |
| |
|
| | with FileLock(lock_file, LOCK_TIMEOUT): |
| | if not os.path.exists(done_file): |
| | with open(hook_file, "w") as f: |
| | if device_type == "cuda": |
| | f.write( |
| | cls.standalone_runtime_cuda_init.format( |
| | cls.find_header("HalideRuntimeCuda.h") |
| | ) |
| | ) |
| | hl.compile_standalone_runtime(a_file, hl.Target(target)) |
| |
|
| | name, output_dir = get_name_and_dir_from_output_file_path(so_file) |
| | halide_cmd_gen = CppBuilder( |
| | name=name, |
| | sources=[hook_file, a_file], |
| | output_dir=output_dir, |
| | BuildOption=CppTorchDeviceOptions( |
| | device_type=device_type, |
| | ), |
| | ) |
| |
|
| | subprocess.check_call( |
| | shlex.split(halide_cmd_gen.get_command_line()) |
| | ) |
| | touch(done_file) |
| | assert os.path.exists(so_file) |
| | cls._standalone_runtime_path = so_file |
| | return so_file |
| |
|
| | @classmethod |
| | def _get_uncompiled_header(cls, device: str) -> str | None: |
| | """Header precompiling is currently disabled for halide.""" |
| | return None |
| |
|
| |
|
| | def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None: |
| | from torch.utils._filelock import FileLock |
| |
|
| | try: |
| | with FileLock(lockfile, LOCK_TIMEOUT): |
| | for job in jobs: |
| | job() |
| | except subprocess.SubprocessError as e: |
| | if os.environ.get("HALIDE_REPRO") == "1": |
| | cmd: list[Any] |
| | python, script, *cmd = getattr(e, "cmd", ("", "", "")) |
| | if os.path.basename(python).startswith("python"): |
| | code = open(script).read() |
| | main = " hl.main()" |
| | assert code.count(main) == 1 |
| |
|
| | class Out: |
| | def __repr__(self) -> str: |
| | return "out" |
| |
|
| | ci = cmd.index("-o") |
| | assert isinstance(ci, int) |
| | cmd[ci + 1] = Out() |
| | repl = textwrap.indent( |
| | textwrap.dedent( |
| | f"""\ |
| | import sys, tempfile |
| | with tempfile.TemporaryDirectory() as out: |
| | sys.argv = {["repro.py", *cmd]!r} |
| | hl.main() |
| | """ |
| | ), |
| | " ", |
| | ) |
| | code = code.replace(main, repl) |
| | with open("repro.py", "w") as fd: |
| | fd.write(code.lstrip()) |
| | raise RuntimeError(f"wrote repro.py: {e}") from e |
| | raise |
| |
|
| |
|
| | def touch(filename: str) -> None: |
| | open(filename, "a").close() |
| |
|
| |
|
| | @clear_on_fresh_cache |
| | class PyCodeCache: |
| | # Modules loaded by this cache. The same path may be loaded more than
| | # once with different attrs, and this list lets cache_clear() remove
| | # the on-disk sources behind every loaded module.
| | modules: list[ModuleType] = [] |
| |
|
| | # Modules loaded without extra attrs are interchangeable, so they are
| | # additionally cached by path for reuse.
| | modules_no_attr: dict[str, ModuleType] = {} |
| |
|
| | linemaps: dict[str, list[tuple[Any, ...]]] = {} |
| |
|
| | @classmethod |
| | def write(cls, source_code: str, extra: str = "") -> tuple[str, str]: |
| | return write(source_code, "py", extra=extra) |
| |
|
| | @classmethod |
| | def load(cls, source_code: str, extra: str = "") -> ModuleType: |
| | key, path = write(source_code, "py", extra=extra) |
| | return cls.load_by_key_path(key, path) |
| |
|
| | @classmethod |
| | def load_by_key_path( |
| | cls, |
| | key: str, |
| | path: str, |
| | linemap: Optional[list[tuple[int, str]]] = None, |
| | attrs: Optional[dict[str, Any]] = None, |
| | ) -> ModuleType: |
| | if linemap is None: |
| | linemap = [] |
| |
|
| | # Reuse a previously loaded attr-less module when possible.
| | if attrs is None and path in cls.modules_no_attr: |
| | return cls.modules_no_attr[path] |
| |
|
| | in_toplevel = in_toplevel_process() |
| | mod = _reload_python_module(key, path, set_sys_modules=in_toplevel) |
| |
|
| | # Cache bookkeeping below only applies in the top-level process.
| | if in_toplevel: |
| | cls.linemaps[path] = list(zip(*linemap)) |
| |
|
| | if attrs is not None: |
| | for k, v in attrs.items(): |
| | setattr(mod, k, v) |
| |
|
| | if in_toplevel: |
| | # Only attr-less modules are safe to share (see modules_no_attr).
| | if attrs is None: |
| | cls.modules_no_attr[path] = mod |
| |
|
| | cls.modules.append(mod) |
| | return mod |
| |
|
| | @classmethod |
| | def cache_clear(cls, purge: bool = False) -> None: |
| | """ |
| | Clear the in-memory module cache. If purge=True, also delete all the |
| | corresponding on-disk source files. |
| | """ |
| | if purge: |
| | for mod in cls.modules: |
| | try: |
| | assert mod.__file__ |
| | os.remove(mod.__file__) |
| | except FileNotFoundError: |
| | pass |
| | cls.modules.clear() |
| | cls.modules_no_attr.clear() |
| |
|
| | @classmethod |
| | @functools.cache |
| | def stack_frames_for_code( |
| | cls, path: str, lineno: int |
| | ) -> Optional[list[dict[str, Any]]]: |
| | if path not in cls.linemaps: |
| | return None |
| | if len(cls.linemaps[path]) == 0: |
| | return None |
| | # linemaps[path] stores the zipped (lines, nodes) columns built in |
| | # load_by_key_path. |
| | lines, nodes = cls.linemaps[path] |
| | p = bisect_right(lines, lineno) |
| | if p == 0: |
| | return None |
| | entry = nodes[p - 1] |
| | if not entry: |
| | return None |
| |
|
| | def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]: |
| | # FX records stack traces as formatted strings rather than structured |
| | # data, so parse them back out; this is not a performance-critical path. |
| | regex = r'File "(.+)", line (\d+), in (.+)\n' |
| | matches = re.findall(regex, stack_trace) |
| | return [ |
| | {"filename": f, "line": int(l), "name": n} |
| | for f, l, n in reversed(matches) |
| | ] |
| |
|
| | return parse_stack_trace(entry) |
| |
|
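| | # A minimal sketch (illustrative only, never called by Inductor) of the |
| | # PyCodeCache round trip: source text is hashed and written to the on-disk |
| | # cache, imported as a module, and extra attributes can be attached at |
| | # load time. The function name and source string are hypothetical. |
| | def _example_pycodecache_roundtrip() -> None: |
| |     src = "def add_one(x):\n    return x + 1\n" |
| |     mod = PyCodeCache.load(src, extra="example") |
| |     assert mod.add_one(41) == 42 |
| |     # Attaching attrs at load time disables the modules_no_attr fast path, |
| |     # so the module is reloaded rather than reused. |
| |     mod2 = PyCodeCache.load_by_key_path(*PyCodeCache.write(src), attrs={"tag": 1}) |
| |     assert mod2.tag == 1 |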
| |
|
| | def _load_triton_kernel_from_source( |
| | kernel_name: str, source_code: str |
| | ) -> CachingAutotuner: |
| | return getattr(PyCodeCache.load(source_code), kernel_name) |
| |
|
| |
|
| | def _cuda_compiler() -> Optional[str]: |
| | if cuda_env.nvcc_exist(config.cuda.cuda_cxx): |
| | return config.cuda.cuda_cxx |
| | if config.is_fbcode(): |
| | return os.path.join(build_paths.sdk_home, "bin", "nvcc") |
| | if cuda_env.nvcc_exist(os.getenv("CUDACXX")): |
| | return os.getenv("CUDACXX", "") |
| | if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): |
| | return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) |
| | return "nvcc" |
| |
|
| |
|
| | def _cutlass_path() -> str: |
| | if config.is_fbcode(): |
| | from libfb.py import parutil |
| |
|
| | return parutil.get_dir_path("cutlass-4-headers") |
| | else: |
| | return config.cuda.cutlass_dir |
| |
|
| |
|
| | def _cutlass_paths() -> list[str]: |
| | return [ |
| | "include", |
| | "tools/library/include", |
| | "tools/library/src", |
| | "tools/util/include", |
| | ] |
| |
|
| |
|
| | def _clone_cutlass_paths(build_root: str) -> list[str]: |
| | paths = _cutlass_paths() |
| | cutlass_root = _cutlass_path() |
| | for path in paths: |
| | old_path = os.path.join(cutlass_root, path) |
| | new_path = os.path.join(build_root, path) |
| | shutil.copytree(old_path, new_path, dirs_exist_ok=True) |
| | return paths |
| |
|
| |
|
| | def _cutlass_include_paths() -> list[str]: |
| | cutlass_path = _cutlass_path() |
| | return [ |
| | # realpath() canonicalizes symlinked locations so the include paths |
| | # (and anything keyed on them) stay stable. |
| | os.path.realpath(os.path.join(cutlass_path, path)) |
| | for path in _cutlass_paths() |
| | ] |
| |
|
| |
|
| | @torch_key_cache |
| | def cutlass_key() -> bytes: |
| | """ |
| | Compute a key representing the state of the CUTLASS library. |
| | |
| | Note: OSS and fbcode will have different keys. |
| | """ |
| | if config.is_fbcode(): |
| | with importlib.resources.path( |
| | "cutlass_library", "src_hash.txt" |
| | ) as resource_path: |
| | with open(resource_path) as resource_file: |
| | return resource_file.read().encode() |
| |
|
| | combined_hash = hashlib.sha256() |
| | build_code_hash([config.cuda.cutlass_dir], "", combined_hash) |
| | return combined_hash.digest() |
| |
|
| |
|
| | def _cuda_lib_options() -> list[str]: |
| | """ |
| | Util function for CUTLASS backend to find the correct CUDA libraries. |
| | """ |
| | _set_gpu_runtime_env() |
| | from torch.utils import cpp_extension |
| |
|
| | lpaths = cpp_extension.library_paths(device_type="cuda") |
| | if use_re_build(): |
| | lpaths += [ |
| | build_paths.sdk_lib, |
| | os.path.join(build_paths.sdk_lib, "stubs"), |
| | ] |
| | extra_ldflags: list[str] = [] |
| | if is_linux(): |
| | _transform_cuda_paths(lpaths) |
| | for path in lpaths: |
| | if "torch/lib" in path: |
| | |
| | continue |
| | extra_ldflags.append(f"-L{path}") |
| | # Embed an rpath so the resulting DLL can find its dependencies at |
| | # load time; skip stub directories, which exist only to satisfy the |
| | # linker and must not shadow the real driver at runtime. |
| | if os.path.basename(path) != "stubs": |
| | extra_ldflags.extend(["-Xlinker", f"-rpath={path}"]) |
| | extra_ldflags.append("-lcuda") |
| | extra_ldflags.append("-lcudart") |
| | else: |
| | raise NotImplementedError( |
| | "Unsupported environment, failed to find CUDA libs! Currently only Linux is supported." |
| | ) |
| | return extra_ldflags |
| |
|
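| | # On a typical Linux box the result looks roughly like the following |
| | # (illustrative values; the real paths come from torch.utils.cpp_extension): |
| | #   ["-L/usr/local/cuda/lib64", "-Xlinker", "-rpath=/usr/local/cuda/lib64", |
| | #    "-lcuda", "-lcudart"] |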
| |
|
| | def _nvcc_host_compiler_options() -> list[str]: |
| | return [ |
| | "-fPIC", |
| | "-fno-strict-aliasing", |
| | "-fvisibility=hidden", |
| | "-Wconversion", |
| | ] |
| |
|
| |
|
| | def _nvcc_arch_as_compile_option() -> str: |
| | arch = cuda_env.get_cuda_arch() |
| | if arch == "90": |
| | # CUTLASS kernels for Hopper need the architecture-conditional |
| | # feature set ("90a"), not plain sm_90. |
| | return "90a" |
| | if arch == "100": |
| | return "100a" |
| | return arch |
| |
|
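| | # Illustrative mapping: "90" -> "90a", "100" -> "100a", and any other |
| | # arch string (e.g. "80", "86") is passed through unchanged. |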
| |
|
| | def _nvcc_compiler_options() -> list[str]: |
| | arch = _nvcc_arch_as_compile_option() |
| | code = [f"sm_{arch}", f"compute_{arch}"] |
| | if config.cuda.enable_cuda_lto: |
| | code += [f"lto_{arch}"] |
| | options = [ |
| | "-t=0", |
| | "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", |
| | "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", |
| | "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", |
| | "-w", |
| | f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", |
| | config.cuda.compile_opt_level, |
| | "-std=c++17", |
| | "--expt-relaxed-constexpr", |
| | "-DNDEBUG", |
| | ] |
| | if config.is_fbcode(): |
| | options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) |
| | if config.cuda.enable_debug_info: |
| | options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) |
| | if config.cuda.enable_ptxas_info: |
| | options.extend( |
| | [ |
| | "--keep", |
| | "--ptxas-options=--warn-on-local-memory-usage", |
| | "--ptxas-options=--warn-on-spills", |
| | "--resource-usage", |
| | "--source-in-ptx", |
| | ] |
| | ) |
| | if config.cuda.use_fast_math: |
| | options.extend( |
| | [ |
| | "--use_fast_math", |
| | "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", |
| | ] |
| | ) |
| | return options |
| |
|
| |
|
| | def cuda_compile_command( |
| | src_files: list[str], |
| | dst_file: str, |
| | dst_file_ext: str, |
| | extra_args: Optional[list[str]] = None, |
| | ) -> str: |
| | if extra_args is None: |
| | extra_args = [] |
| | if use_re_build(): |
| | build_path = os.path.dirname(dst_file) |
| | include_paths = _clone_cutlass_paths(build_path) |
| | src_files = [os.path.basename(src_file) for src_file in src_files] |
| | dst_file = os.path.basename(dst_file) |
| | else: |
| | include_paths = _cutlass_include_paths() |
| | cuda_lib_options = _cuda_lib_options() |
| | nvcc_host_compiler_options = _nvcc_host_compiler_options() |
| | nvcc_compiler_options = _nvcc_compiler_options() |
| | options = ( |
| | nvcc_compiler_options |
| | + extra_args |
| | + [ |
| | f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" |
| | for opt in nvcc_host_compiler_options |
| | ] |
| | + ["-I" + path for path in include_paths] |
| | + cuda_lib_options |
| | ) |
| | src_file = " ".join(src_files) |
| | res = "" |
| | if dst_file_ext == "o": |
| | res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" |
| | elif dst_file_ext == "so": |
| | options.append("-shared") |
| | res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" |
| | elif dst_file_ext == "exe": |
| | res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" |
| | else: |
| | raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") |
| | if log.isEnabledFor(logging.DEBUG): |
| | log.debug("CUDA command: %s", res) |
| | else: |
| | autotuning_log.debug("CUDA command: %s", res) |
| | return res |
| |
|
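| | # Illustrative shape of an assembled .so command (paths, arch, and opt |
| | # level are made up; the real values come from config and the helpers |
| | # above): |
| | #   nvcc -t=0 ... -gencode=arch=compute_90a,code=[sm_90a,compute_90a] |
| | #   -O1 -std=c++17 ... -Xcompiler=-fPIC -I/path/to/cutlass/include ... |
| | #   -L/usr/local/cuda/lib64 -lcuda -lcudart -shared -o kernel.so kernel.cu |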
| |
|
| | class DLLWrapper: |
| | """A wrapper for a dynamic library.""" |
| |
|
| | def __init__( |
| | self, |
| | lib_path: str, |
| | ) -> None: |
| | self.lib_path = lib_path |
| | self.is_open = False |
| | self.DLL = cdll.LoadLibrary(lib_path) |
| | self.is_open = True |
| |
|
| | def close(self) -> None: |
| | if self.is_open: |
| | self._dlclose() |
| | self.is_open = False |
| |
|
| | def _dlclose(self) -> None: |
| | f_dlclose = None |
| |
|
| | if is_linux(): |
| | syms = CDLL(None) |
| | if not hasattr(syms, "dlclose"): |
| | # Some libcs (e.g. musl on Alpine) do not expose dlclose through |
| | # the main namespace; try libc directly. |
| | syms = CDLL("libc.so") |
| |
|
| | if hasattr(syms, "dlclose"): |
| | f_dlclose = syms.dlclose |
| | elif is_windows(): |
| | import ctypes |
| |
|
| | kernel32 = ctypes.CDLL("kernel32", use_last_error=True) |
| |
|
| | f_dlclose = kernel32.FreeLibrary |
| | else: |
| | raise NotImplementedError("Unsupported env, failed to do dlclose!") |
| |
|
| | if f_dlclose is not None: |
| | if is_linux(): |
| | f_dlclose.argtypes = [c_void_p] |
| | f_dlclose(self.DLL._handle) |
| | elif is_windows(): |
| | import ctypes |
| | from ctypes import wintypes |
| |
|
| | f_dlclose.argtypes = [wintypes.HMODULE] |
| | f_dlclose(self.DLL._handle) |
| | else: |
| | log.warning( |
| | "DLL unloading function was not found; the library may not be unloaded properly!" |
| | ) |
| |
|
| | def __getattr__(self, name: str) -> Callable[..., None]: |
| | if not self.is_open: |
| | raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") |
| |
|
| | method = getattr(self.DLL, name) |
| |
|
| | def _wrapped_func(*args: Any) -> None: |
| | err = method(*args) |
| | if err: |
| | raise RuntimeError(f"Error in function: {method.__name__}") |
| |
|
| | return _wrapped_func |
| |
|
| | def __enter__(self) -> Self: |
| | return self |
| |
|
| | def __exit__(self, *args: Any) -> None: |
| | self.close() |
| |
|
| | def __del__(self) -> None: |
| | self.close() |
| |
|
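| | # Minimal usage sketch (hypothetical .so path): the wrapper is a context |
| | # manager, and attribute access returns a checked wrapper that raises |
| | # RuntimeError on a nonzero return code. |
| | # |
| | #   with DLLWrapper("/tmp/kernel.so") as dll:   # hypothetical library |
| | #       dll.some_exported_fn(args_ptr)          # raises on nonzero status |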
| |
|
| | @lru_cache |
| | def binary_error_path(output_path: str) -> str: |
| | """ |
| | standard format for the error path |
| | """ |
| | return output_path + ".error" |
| |
|
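| | # e.g. binary_error_path("/cache/ab12.so") == "/cache/ab12.so.error" |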
| |
|
| | @clear_on_fresh_cache |
| | class CUDACodeCache: |
| | """ |
| | A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS. |
| | This class handles writing source code to files, compiling them into shared objects, and caching |
| | the results to avoid redundant compilations. It also manages error handling and logging for the |
| | compilation process. |
| | """ |
| |
|
| | @dataclasses.dataclass |
| | class CacheEntry: |
| | input_path: str |
| | output_path: str |
| | error_json: Optional[str] = None |
| |
|
| | cache: dict[str, CacheEntry] = {} |
| | aot_kernels_o: list[str] = [] |
| | _SOURCE_CODE_SUFFIX = "cu" |
| |
|
| | @staticmethod |
| | def cache_clear() -> None: |
| | CUDACodeCache.cache.clear() |
| | CUDACodeCache.aot_kernels_o.clear() |
| |
|
| | @staticmethod |
| | @lru_cache(maxsize=4) |
| | def get_kernel_binary_remote_cache( |
| | caching_enabled: bool, caching_available: bool |
| | ) -> Optional[Any]: |
| | """ |
| | Get or create the class instance of the CUTLASSKernelBinaryRemoteCache. |
| | |
| | Args: |
| | caching_enabled: Whether binary remote caching is enabled |
| | caching_available: Whether we're in fbcode environment |
| | |
| | Returns: |
| | CUTLASSKernelBinaryRemoteCache: The class instance of the kernel binary remote cache |
| | """ |
| | if not caching_enabled: |
| | log.debug("CUTLASSKernelBinaryRemoteCache not requested, skipping") |
| | return None |
| | if not caching_available: |
| | return None |
| |
|
| | try: |
| | from torch._inductor.fb.kernel_binary_remote_cache import ( |
| | CUTLASSKernelBinaryRemoteCache, |
| | ) |
| |
|
| | return CUTLASSKernelBinaryRemoteCache() |
| | except ImportError: |
| | log.debug( |
| | "CUTLASSKernelBinaryRemoteCache not available, remote caching disabled" |
| | ) |
| | return None |
| |
|
| | @classmethod |
| | @lru_cache(None) |
| | def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: |
| | """ |
| | Writes source code into a file with dst_file_ext as the file extension. |
| | Returns the hash key of source code, and the path to the file. |
| | """ |
| |
|
| | if config.cuda.cutlass_hash_with_compile_cmd: |
| | cuda_command = repr( |
| | cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) |
| | ) |
| | extra = cuda_command |
| | else: |
| | # Key on everything that can change the produced binary: the |
| | # compiler, both sets of flags, and the CUTLASS sources themselves. |
| | extra = repr( |
| | [ |
| | # the nvcc compiler in use |
| | _cuda_compiler(), |
| | # device compiler flags |
| | _nvcc_compiler_options(), |
| | # host compiler flags |
| | _nvcc_host_compiler_options(), |
| | # hash of the CUTLASS library sources |
| | cutlass_key(), |
| | ] |
| | ) |
| | key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra) |
| | return key, input_path |
| |
|
| | @classmethod |
| | def compile( |
| | cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None |
| | ) -> tuple[str, str, str]: |
| | """ |
| | Compiles CUDA source_code into a file with dst_file_ext extension. |
| | If dst_file_ext is "so", first compiles to ".o" and then links to ".so". |
| | Returns a tuple of dst_file_path, hash_key, source_code_path |
| | """ |
| | if dst_file_ext == "so": |
| | |
| | obj_path, _, _ = cls.compile(source_code, "o", extra_args) |
| | key, input_path = cls.write(source_code, dst_file_ext) |
| | src_files, operation_name = [obj_path], "Linking" |
| | else: |
| | |
| | key, input_path = cls.write(source_code, dst_file_ext) |
| | src_files, operation_name = [input_path], "Compilation" |
| |
|
| | key_with_ext = key + dst_file_ext |
| | if key_with_ext not in cls.cache: |
| | from torch.utils._filelock import FileLock |
| |
|
| | lock_dir = get_lock_dir() |
| | lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) |
| | with lock: |
| | output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext |
| | error_path = binary_error_path(output_path) |
| | binary_remote_cache = cls.get_kernel_binary_remote_cache( |
| | caching_enabled=config.cuda.use_binary_remote_cache |
| | and not config.force_disable_caches, |
| | caching_available=config.is_fbcode(), |
| | ) |
| | if binary_remote_cache is not None: |
| | # Fetch any previously built binary (or recorded failure) into |
| | # the local paths before attempting to compile. |
| | binary_remote_cache.get(output_path, error_path) |
| |
|
| | if os.path.exists(error_path): |
| | with open(error_path, encoding="utf-8") as fh: |
| | error_json = fh.read() |
| | cmd_parts, error_output = json.loads(error_json) |
| | if ( |
| | binary_remote_cache is not None |
| | and config.cuda.upload_to_binary_remote_cache |
| | ): |
| | # Upload the recorded failure so other workers can skip |
| | # recompiling a source that is known to fail. |
| | binary_remote_cache.put( |
| | error_path, config.cuda.binary_remote_cache_force_write |
| | ) |
| | cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( |
| | input_path, output_path, error_json |
| | ) |
| | raise exc.CUDACompileError(cmd_parts, error_output) |
| | if not os.path.exists(output_path): |
| | cmd = cuda_compile_command( |
| | src_files, output_path, dst_file_ext, extra_args |
| | ) |
| | with open(input_path, "a") as f: |
| | f.write("\n") |
| | f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n") |
| | start_time = time() |
| | log.debug("CUDA %s: %s", operation_name, cmd) |
| | cmd_parts = cmd.split(" ") |
| | try: |
| | if use_re_build(): |
| | from triton.fb.re_build_helper import run_build_command |
| |
|
| | run_build_command( |
| | cmd_parts, |
| | os.path.dirname(input_path), |
| | os.path.basename(output_path), |
| | ) |
| | else: |
| | subprocess.check_output( |
| | cmd_parts, stderr=subprocess.STDOUT, env=os.environ |
| | ) |
| | except subprocess.CalledProcessError as error: |
| | cls._record_cuda_compile_error( |
| | error.output.decode("utf-8"), |
| | key_with_ext, |
| | cmd_parts, |
| | input_path, |
| | output_path, |
| | binary_remote_cache, |
| | ) |
| | raise exc.CUDACompileError(cmd_parts, error.output) from error |
| | except Exception as error: |
| | if "COMPILE FAILED WITH" in str(error): |
| | cls._record_cuda_compile_error( |
| | str(error), |
| | key_with_ext, |
| | cmd_parts, |
| | input_path, |
| | output_path, |
| | binary_remote_cache, |
| | ) |
| | raise exc.CUDACompileError(cmd_parts, str(error)) from error |
| | raise error |
| | end_time = time() |
| | log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}" |
| | log.info(log_duration_msg) |
| |
|
| | else: |
| | log.debug( |
| | "CUDA %s skipped: %s since output already exists", |
| | operation_name, |
| | output_path, |
| | ) |
| | # Upload the compiled binary whether it was just built or already |
| | # present on disk from an earlier run. |
| | if ( |
| | binary_remote_cache is not None |
| | and config.cuda.upload_to_binary_remote_cache |
| | ): |
| | binary_remote_cache.put( |
| | output_path, config.cuda.binary_remote_cache_force_write |
| | ) |
| | cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( |
| | input_path, output_path, None |
| | ) |
| |
|
| | cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext] |
| | if cache_entry.error_json is not None: |
| | # A cached failure: surface the original compile error again. |
| | cmd_parts, error_output = json.loads(cache_entry.error_json) |
| | raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) |
| | return (cls.cache[key_with_ext].output_path, key, input_path) |
| |
|
| | @classmethod |
| | def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: |
| | """ |
| | Compiles source code and loads the generated .so file. |
| | Returns a tuple of DLLWrapper, hash_key, source_code_path |
| | """ |
| |
|
| | if dst_file_ext != "so": |
| | raise RuntimeError( |
| | f"Only support loading a .so file for now. " |
| | f"Requested file extension: {dst_file_ext}. Source code: {source_code}" |
| | ) |
| | dst_file_path, hash_key, source_code_path = cls.compile( |
| | source_code, dst_file_ext |
| | ) |
| | return (DLLWrapper(dst_file_path), hash_key, source_code_path) |
| |
|
| | @classmethod |
| | def _record_cuda_compile_error( |
| | cls, |
| | error_str: str, |
| | key_with_ext: str, |
| | cmd_parts: list[str], |
| | input_path: str, |
| | output_path: str, |
| | # Typed as Any because the remote-cache class only exists in fbcode |
| | # builds and cannot be imported unconditionally. |
| | binary_remote_cache: Any = None, |
| | ) -> None: |
| | error_json = json.dumps([cmd_parts, error_str]) |
| | cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( |
| | input_path, output_path, error_json |
| | ) |
| | error_path = binary_error_path(output_path) |
| | with open(error_path, "w", encoding="utf-8") as fh: |
| | fh.write(error_json) |
| |
|
| | # Upload the recorded error to the remote cache, if configured. |
| | if ( |
| | binary_remote_cache is not None |
| | and config.cuda.upload_to_binary_remote_cache |
| | ): |
| | binary_remote_cache.put( |
| | error_path, config.cuda.binary_remote_cache_force_write |
| | ) |
| |
|
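| | # A minimal sketch of the compile-and-load flow (illustrative only: it is |
| | # never called by Inductor and assumes nvcc plus the CUTLASS headers are |
| | # available; the C source and symbol name are hypothetical). |
| | def _example_cudacodecache_load() -> None: |
| |     src = 'extern "C" int noop() { return 0; }' |
| |     # load() compiles .cu -> .o -> .so (each step cached by content hash) |
| |     # and returns the opened DLLWrapper plus the key and source path. |
| |     dll, _key, _src_path = CUDACodeCache.load(src, "so") |
| |     with dll: |
| |         dll.noop()  # wrapped call; raises RuntimeError on nonzero return |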
| |
|
| | @clear_on_fresh_cache |
| | class ROCmCodeCache: |
| | """ |
| | Compile-and-load cache for ROCm/HIP source code, mirroring CUDACodeCache: |
| | sources are written to disk, compiled behind a file lock, and results |
| | are cached in memory by content hash. |
| | """ |
| | @dataclasses.dataclass |
| | class CacheEntry: |
| | input_path: str |
| | output_path: str |
| |
|
| | cache: dict[str, CacheEntry] = {} |
| | aot_kernels_o: list[str] = [] |
| | _SOURCE_CODE_SUFFIX = "cpp" |
| | _logged_compiler_version = False |
| |
|
| | @staticmethod |
| | def cache_clear() -> None: |
| | ROCmCodeCache.cache.clear() |
| | ROCmCodeCache.aot_kernels_o.clear() |
| |
|
| | @classmethod |
| | def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: |
| | """ |
| | Writes source code into a file with dst_file_ext as the file extension. |
| | Returns the hash key of source code, and the path to the file. |
| | """ |
| |
|
| | rocm_command = repr( |
| | rocm_compile_command(["dummy_input"], "dummy_output", dst_file_ext) |
| | ) |
| | key, input_path = write( |
| | source_code, cls._SOURCE_CODE_SUFFIX, extra=rocm_command |
| | ) |
| | return key, input_path |
| |
|
| | @classmethod |
| | def compile( |
| | cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None |
| | ) -> tuple[str, str, str]: |
| | """ |
| | Compiles source_code into a file with dst_file_ext extension, |
| | using the compile command specific for the ROCm platform. |
| | Returns a tuple of dst_file_path, hash_key, source_code_path |
| | """ |
| | if not cls._logged_compiler_version: |
| | cls._logged_compiler_version = True |
| | log.debug(get_compiler_version_info(str(rocm_compiler()))) |
| |
|
| | key, input_path = cls.write(source_code, dst_file_ext) |
| | if key not in cls.cache: |
| | from torch.utils._filelock import FileLock |
| |
|
| | lock_dir = get_lock_dir() |
| | lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) |
| | with lock: |
| | output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext |
| | if not os.path.exists(output_path): |
| | cmd = rocm_compile_command( |
| | [input_path], output_path, dst_file_ext, extra_args |
| | ) |
| | start_time = time() |
| | cmd_parts = cmd.split(" ") |
| | try: |
| | output = subprocess.check_output( |
| | cmd_parts, |
| | stderr=subprocess.STDOUT, |
| | text=True, |
| | env=os.environ, |
| | ) |
| | log.debug("Compilation output: %s", output) |
| | except subprocess.CalledProcessError as error: |
| | # No ROCm-specific exception type exists; CUDACompileError is |
| | # reused for HIP compilation failures as well. |
| | raise exc.CUDACompileError(cmd_parts, error.output) from error |
| | end_time = time() |
| | log_duration_msg = f"Compilation took {end_time - start_time} seconds. Compile command: {cmd}" |
| | log.info(log_duration_msg) |
| | else: |
| | log.debug( |
| | "Skip compiling %s: output %s already exists", |
| | input_path, |
| | output_path, |
| | ) |
| | cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path) |
| |
|
| | return (cls.cache[key].output_path, key, input_path) |
| |
|
| | @classmethod |
| | def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: |
| | """ |
| | Compiles source code and loads the generated .so file. |
| | Returns a tuple of DLLWrapper, hash_key, source_code_path |
| | """ |
| |
|
| | if dst_file_ext != "so": |
| | raise RuntimeError( |
| | f"Only support loading a .so file for now. " |
| | f"Requested file extension: {dst_file_ext}. Source code: {source_code}" |
| | ) |
| | dst_file_path, hash_key, source_code_path = cls.compile( |
| | source_code, dst_file_ext |
| | ) |
| | return (DLLWrapper(dst_file_path), hash_key, source_code_path) |
| |
|
| |
|
| | class CodeCacheFuture: |
| | """Minimal future interface for asynchronously produced compile results.""" |
| |
|
| | def result(self) -> Callable[..., Any]: |
| | raise NotImplementedError |
| |
|
| |
|
| | class LambdaFuture(CodeCacheFuture): |
| | def __init__( |
| | self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None |
| | ) -> None: |
| | self.result_fn = result_fn |
| | self.future = future |
| |
|
| | def result(self) -> Callable[..., Any]: |
| | return self.result_fn() |
| |
|
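| | # Illustrative use of the future protocol (hypothetical callable): the work |
| | # is deferred until .result() is called, optionally gated on a real Future. |
| | # |
| | #   fut = LambdaFuture(lambda: compiled_kernel) |
| | #   kernel = fut.result()  # invokes the lambda |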
| |
|
| | class StaticAutotunerFuture(CodeCacheFuture): |
| | """ |
| | A statically launchable CachingAutotuner, loaded from TritonBundler |
| | """ |
| |
|
| | def __init__(self, static_autotuner: CachingAutotuner) -> None: |
| | # The CachingAutotuner deserialized from the Triton bundle. |
| | self.static_autotuner = static_autotuner |
| | # Filled in later (by AsyncCompile.triton) so the kernel can be |
| | # reloaded from source if some state could not be bundled statically; |
| | # result() asserts it has been set before precompiling. |
| | self.reload_kernel_from_src: Optional[Callable[[], Any]] = None |
| |
|
| | def result(self) -> CachingAutotuner: |
| | assert self.reload_kernel_from_src is not None |
| | with dynamo_timed("StaticAutotunerFuture.warm_precompile"): |
| | self.static_autotuner.recheck_autotune_cache( |
| | reload_kernel_from_src=self.reload_kernel_from_src |
| | ) |
| | self.static_autotuner.precompile( |
| | warm_cache_only=False, |
| | reload_kernel=self.reload_kernel_from_src, |
| | static_triton_bundle_key=None, |
| | ) |
| | return self.static_autotuner |
| |
|
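| | # Expected flow (see AsyncCompile.triton): the bundler constructs this |
| | # future, the caller fills in reload_kernel_from_src, and result() then |
| | # rechecks the autotune cache and precompiles before handing back the |
| | # CachingAutotuner. |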