Instructions to use kernels-community/flash-attn4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn4 with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/flash-attn4")
```
- Notebooks
- Google Colab
- Kaggle
| # Manage Ahead-of-Time (AOT) compiled kernels | |
| import fcntl | |
| import hashlib | |
| import logging | |
| import os | |
| import pickle | |
| import sys | |
| import tempfile | |
| import time | |
| from functools import lru_cache | |
| from getpass import getuser | |
| from pathlib import Path | |
| from typing import Hashable, TypeAlias | |
| import ctypes | |
| import cutlass | |
| import cutlass.cute as cute | |
| import tvm_ffi | |
| from cutlass.cutlass_dsl import JitCompiledFunction | |
# The cute DSL runtime libraries are loaded up front with RTLD_GLOBAL so that
# the symbols they export (e.g. _cudaLibraryLoadData) can be resolved by .so
# modules that get dlopen'ed afterwards. Upstream cute.runtime.load_module
# loads them without RTLD_GLOBAL, which leads to "undefined symbol" errors
# when cached kernels are loaded from disk.
for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
    # Skip library paths that are not present on this machine.
    if not Path(_lib_path).exists():
        continue
    ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)
# Key identifying one compiled variant, and the callables the caches may hold.
CompileKeyType: TypeAlias = tuple[Hashable, ...]
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function

# Module logger: emits timestamped records through its own stream handler,
# with the logger level set to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
_handler = logging.StreamHandler()
_handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
)
logger.addHandler(_handler)

# Persistent caching is opt-in: set `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`.
CUTE_DSL_CACHE_ENABLED: bool = (
    os.environ.get("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
)

# `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR` overrides the cache directory; when it
# is unset the default is `/tmp/${USER}/flash_attention_cute_dsl_cache`.
CUTE_DSL_CACHE_DIR: str | None = os.environ.get("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR")
def get_cache_path() -> Path:
    """Return the root directory of the persistent cache, creating it if needed.

    The directory is ``CUTE_DSL_CACHE_DIR`` when that module constant is set;
    otherwise it is ``<tempdir>/<user>/flash_attention_cute_dsl_cache``.

    Returns:
        Path to the (now existing) cache directory.
    """
    if CUTE_DSL_CACHE_DIR is not None:
        cache_dir = Path(CUTE_DSL_CACHE_DIR)
    else:
        try:
            user = getuser()
        except (KeyError, OSError):
            # getuser() fails when the current UID has no passwd entry and no
            # LOGNAME/USER/... variable is set (typical for containers started
            # with an arbitrary UID). Fall back to the numeric uid so each
            # user still gets a separate directory instead of a crash.
            user = f"uid{os.getuid()}"
        # NOTE(review): this default lives under the shared temp directory and
        # the cached artifacts are native code that gets loaded later; consider
        # restricting directory permissions if the machine is multi-user.
        cache_dir = Path(tempfile.gettempdir()) / user / "flash_attention_cute_dsl_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
@lru_cache(maxsize=1)
def _compute_source_fingerprint() -> str:
    """
    Hash all Python sources next to this module plus runtime ABI stamps.

    The fingerprint (a SHA-256 hex digest) changes whenever:
    - Any .py file under the directory containing this file (searched
      recursively) is added, removed, renamed, or modified.
    - The Python minor version changes (e.g. 3.13 -> 3.14).
    - The cutlass or tvm_ffi package version changes.

    Computed once per process and cached via ``lru_cache``, so repeated cache
    construction does not re-read and re-hash the source tree.
    """
    cute_root = Path(__file__).resolve().parent
    h = hashlib.sha256()

    # Runtime stamps first: interpreter version and the two package versions.
    h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode())
    h.update(f"cutlass={cutlass.__version__}".encode())
    h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())

    # Sorted traversal keeps the digest independent of directory listing order.
    for src in sorted(cute_root.rglob("*.py")):
        if not src.is_file():
            continue
        # The relative path is hashed too, so renames change the fingerprint.
        h.update(src.relative_to(cute_root).as_posix().encode())
        content = src.read_bytes()
        # Length prefix delimits each file's bytes within the hashed stream.
        h.update(len(content).to_bytes(8, "little"))
        h.update(content)
    return h.hexdigest()
class FileLock:
    """Context manager for advisory file locks using fcntl.flock.

    Supports exclusive (write) and shared (read) locks.
    Acquisition is non-blocking under the hood: the lock is polled until it is
    granted or the timeout elapses. At least one attempt is always made, even
    with ``timeout=0``.

    Usage:
        with FileLock(lock_path, exclusive=True, timeout=15, label="abc"):
            # do work under lock
    """

    # Seconds to sleep between two acquisition attempts.
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(
        self,
        lock_path: Path,
        exclusive: bool,
        timeout: float = 15,
        label: str = "",
    ):
        """
        Args:
            lock_path: Path to the lock file on disk (created if missing).
            exclusive: True for exclusive (write) lock, False for shared (read) lock.
            timeout: Max seconds to wait for lock acquisition before raising RuntimeError.
            label: Optional human-readable label for error messages.
        """
        self.lock_path: Path = lock_path
        self.exclusive: bool = exclusive
        self.timeout: float = timeout
        self.label: str = label
        # None means "no descriptor open"; __exit__ relies on this sentinel.
        self._fd: int | None = None

    def _lock_label(self) -> str:
        """Return a description such as ``"exclusive abc"`` for error messages."""
        kind = "exclusive" if self.exclusive else "shared"
        return f"{kind} {self.label}" if self.label else kind

    @staticmethod
    def _try_flock(fd: int, lock_type: int) -> bool:
        """Make one non-blocking flock attempt; return whether it was granted."""
        try:
            fcntl.flock(fd, lock_type | fcntl.LOCK_NB)
        except OSError:
            # Treated as "not available yet"; the caller polls until its deadline.
            return False
        return True

    def __enter__(self) -> "FileLock":
        """Open the lock file and poll until the lock is held.

        Raises:
            RuntimeError: if the lock could not be acquired within ``timeout``.
        """
        access_mode = os.O_WRONLY if self.exclusive else os.O_RDONLY
        lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
        fd = os.open(str(self.lock_path), access_mode | os.O_CREAT)
        deadline = time.monotonic() + self.timeout
        try:
            while not self._try_flock(fd, lock_type):
                if time.monotonic() >= deadline:
                    raise RuntimeError(
                        f"Timed out after {self.timeout}s waiting for "
                        f"{self._lock_label()} lock: {self.lock_path}"
                    )
                time.sleep(self._POLL_INTERVAL_SECONDS)
        except BaseException:
            # Never leak the descriptor, whether we timed out or were interrupted.
            os.close(fd)
            raise
        self._fd = fd
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Release the lock and close the descriptor, if one is open."""
        if self._fd is None:
            return
        fd, self._fd = self._fd, None
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            # Closing also drops the lock, so do it even if the unlock failed.
            os.close(fd)
class JITCache:
    """
    Process-local mapping from compile keys to compiled functions.

    Entries live only in memory; nothing is written to disk.
    """

    def __init__(self):
        # Backing store: compile key -> callable.
        self.cache: dict[CompileKeyType, CallableFunction] = dict()

    def __contains__(self, key: CompileKeyType) -> bool:
        """Return whether a function is stored under ``key``."""
        present = key in self.cache
        return present

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        """Return the function stored under ``key``; raises KeyError if absent."""
        fn = self.cache[key]
        return fn

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Store ``fn`` under ``key``, replacing any previous entry."""
        self.cache[key] = fn

    def clear(self) -> None:
        """
        Drop every compiled function held in memory
        """
        self.cache.clear()
class JITPersistentCache(JITCache):
    """
    In-memory cache for compiled functions, which is also backed by persistent storage.

    Uses cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True.
    Each key maps to ``<sha256>.o`` inside ``cache_path``; a sibling
    ``<sha256>.lock`` file serialises readers and writers of that artifact.
    """

    EXPORT_FUNCTION_PREFIX = "func"
    LOCK_TIMEOUT_SECONDS = 15

    def __init__(self, cache_path: Path):
        """
        Args:
            cache_path: Directory holding the exported objects; created if missing.
        """
        super().__init__()
        cache_path.mkdir(parents=True, exist_ok=True)
        self.cache_path: Path = cache_path

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Store ``fn`` in memory and export it to persistent storage."""
        JITCache.__setitem__(self, key, fn)
        self._try_export_to_storage(key, fn)

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        """Return the function for ``key``, loading it from disk when needed."""
        # Use __contains__ to try populating in-memory cache with persistent storage
        self.__contains__(key)
        return JITCache.__getitem__(self, key)

    def __contains__(self, key: CompileKeyType) -> bool:
        """Return whether ``key`` is available in memory or on disk."""
        # Checks in-memory cache first, then tries loading from storage.
        # When returning True, guarantees the in-memory cache is populated.
        if JITCache.__contains__(self, key):
            return True
        return self._try_load_from_storage(key)

    def _try_load_from_storage(self, key: CompileKeyType) -> bool:
        """
        Try to load a function from persistent storage into in-memory cache.

        Returns True if loaded successfully, False if not found on disk.
        Holds a shared lock during loading to prevent concurrent writes.
        """
        sha256_hex = self._key_to_hash(key)
        obj_path = self._object_path(sha256_hex)
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=False,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            if obj_path.exists():
                logger.debug("Loading compiled function from disk: %s", obj_path)
                m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
                fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
                JITCache.__setitem__(self, key, fn)
                return True
            else:
                logger.debug("Cache miss on disk for key hash %s", sha256_hex)
                return False

    def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Export a compiled function to persistent storage under exclusive lock.

        The object is written to a temporary file and then renamed into place,
        so an export that fails or is interrupted never leaves a truncated
        ``<sha256>.o`` behind for later processes to load.
        """
        sha256_hex = self._key_to_hash(key)
        obj_path = self._object_path(sha256_hex)
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=True,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            if obj_path.exists():
                # Another process already exported.
                logger.debug("Skipping export, already on disk: %s", obj_path)
                return
            logger.debug("Exporting compiled function to disk: %s", obj_path)
            # Same directory as the target so that os.replace is a plain rename.
            tmp_path = self.cache_path / f"{sha256_hex}.{os.getpid()}.tmp.o"
            try:
                fn.export_to_c(
                    object_file_path=str(tmp_path),
                    function_name=self.EXPORT_FUNCTION_PREFIX,
                )
                os.replace(tmp_path, obj_path)
            finally:
                # No-op after a successful rename; removes leftovers on failure.
                tmp_path.unlink(missing_ok=True)
            logger.debug("Successfully exported compiled function to disk: %s", obj_path)

    def _key_to_hash(self, key: CompileKeyType) -> str:
        """Return the SHA-256 hex digest of the pickled ``key``."""
        return hashlib.sha256(pickle.dumps(key)).hexdigest()

    def _object_path(self, sha256_hex: str) -> Path:
        """Return the path of the exported object file for a key hash."""
        return self.cache_path / f"{sha256_hex}.o"

    def _lock_path(self, sha256_hex: str) -> Path:
        """Return the path of the lock file guarding a key hash."""
        return self.cache_path / f"{sha256_hex}.lock"

    def clear(self) -> None:
        """
        Not only clear the in-memory cache. Also purge persistent compilation cache.

        Only entries directly inside ``cache_path`` are removed. Subdirectories
        are left untouched (``Path.unlink`` cannot remove them).
        """
        logger.debug("Clearing persistent cache at %s", self.cache_path)
        super().clear()
        for child in self.cache_path.iterdir():
            if child.is_dir() and not child.is_symlink():
                continue
            # missing_ok: another process may be clearing the same directory.
            child.unlink(missing_ok=True)
def get_jit_cache(name: str | None = None) -> JITCache:
    """
    Build the JIT cache appropriate for the current configuration.

    With persistent caching disabled, a plain in-memory ``JITCache`` is
    returned. With it enabled, a ``JITPersistentCache`` is returned whose
    directory is ``<cache root>/<source fingerprint>[/<name>]``; keying the
    directory on the fingerprint means code or dependency changes
    automatically invalidate stale entries.

    `name` is an optional identifier selecting a subdirectory for this cache.
    """
    if not CUTE_DSL_CACHE_ENABLED:
        logger.debug("Persistent cache disabled, using in-memory JIT cache")
        return JITCache()

    cache_dir = get_cache_path() / _compute_source_fingerprint()
    if name:
        cache_dir /= name
    logger.debug("Creating persistent JIT cache at %s", cache_dir)
    return JITPersistentCache(cache_dir)