Instructions to use kernels-community/flash-attn4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn4 with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/flash-attn4")
- Notebooks
- Google Colab
- Kaggle
File size: 10,231 Bytes
70b4af3
# Manage Ahead-of-Time (AOT) compiled kernels
import fcntl
import hashlib
import logging
import os
import pickle
import sys
import tempfile
import time
from functools import lru_cache
from getpass import getuser
from pathlib import Path
from typing import Hashable, TypeAlias
import ctypes
import cutlass
import cutlass.cute as cute
import tvm_ffi
from cutlass.cutlass_dsl import JitCompiledFunction
# Load every CuTe DSL runtime library that exists on disk with RTLD_GLOBAL
# before any cached kernel is loaded. Per the upstream behaviour described
# here, cute.runtime.load_module opens these libraries without RTLD_GLOBAL, so
# their symbols (e.g. _cudaLibraryLoadData) would not be visible to .so modules
# opened later via dlopen, producing "undefined symbol" errors when cached
# kernels are loaded from disk.
for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
    if not Path(_lib_path).exists():
        # Reported by the runtime but not present on this machine: skip it.
        continue
    ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)
# Key identifying one compiled kernel variant: a tuple of hashable values.
CompileKeyType: TypeAlias = tuple[Hashable, ...]
# Value type stored in the caches: a JIT-compiled function or a tvm_ffi function
# (presumably the latter is what loading an exported module yields - TODO confirm).
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
# Module logger with its own stream handler and a millisecond-resolution timestamp format.
# NOTE(review): attaching a handler and forcing DEBUG at import time affects every
# importer of this module - confirm this is intended for a library module.
logger = logging.getLogger(__name__)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
logger.addHandler(_handler)
logger.setLevel(logging.DEBUG)
# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`.
# Only the exact string "1" enables it; the value is read once at import time.
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
# Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`, default is
# `/tmp/${USER}/flash_attention_cute_dsl_cache`
# (None here means "use the default"; the value is read once at import time).
CUTE_DSL_CACHE_DIR: str | None = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR", None)
def get_cache_path() -> Path:
    """Return the root directory of the persistent kernel cache, creating it if needed.

    The directory is `CUTE_DSL_CACHE_DIR` when that is set; otherwise it is
    `<system temp dir>/<user name>/flash_attention_cute_dsl_cache`.

    Returns:
        Path to the (now existing) cache root directory.

    Raises:
        OSError: if the directory cannot be created.
    """
    if CUTE_DSL_CACHE_DIR is not None:
        cache_dir = Path(CUTE_DSL_CACHE_DIR)
    else:
        try:
            user = getuser()
        except (KeyError, OSError, ImportError):
            # getuser() raises when no user environment variable is set and the
            # current uid has no password-database entry (typical for containers
            # run with an arbitrary uid). Fall back to a uid-based name so the
            # cache stays usable and remains separated per user.
            user = f"uid{os.getuid()}"
        cache_dir = Path(tempfile.gettempdir()) / user / "flash_attention_cute_dsl_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
@lru_cache(maxsize=1)
def _compute_source_fingerprint() -> str:
    """
    Build a SHA-256 fingerprint of the Python sources next to this module plus
    runtime version stamps.

    The digest covers, in order:
    - The Python major.minor version.
    - The cutlass and tvm_ffi package versions.
    - For every regular .py file found recursively under this module's
      directory (in sorted path order): its relative POSIX path, its size as an
      8-byte little-endian integer, and its bytes.

    Adding, removing, renaming or editing any of those files, or changing any
    of the versions, therefore yields a different fingerprint.
    The result is memoized, so the work happens once per process.

    Returns:
        The hex-encoded SHA-256 digest.
    """
    package_dir = Path(__file__).resolve().parent
    digest = hashlib.sha256()

    abi_stamps = (
        f"py{sys.version_info.major}.{sys.version_info.minor}",
        f"cutlass={cutlass.__version__}",
        f"tvm_ffi={tvm_ffi.__version__}",
    )
    for stamp in abi_stamps:
        digest.update(stamp.encode())

    # Sorted so the digest does not depend on directory traversal order.
    source_files = sorted(p for p in package_dir.rglob("*.py") if p.is_file())
    for source_file in source_files:
        payload = source_file.read_bytes()
        digest.update(source_file.relative_to(package_dir).as_posix().encode())
        # The length prefix delimits each file's content inside the hash stream.
        digest.update(len(payload).to_bytes(8, "little"))
        digest.update(payload)
    return digest.hexdigest()
class FileLock:
    """Context manager for advisory file locks using fcntl.flock.

    Supports exclusive (write) and shared (read) locks.
    Acquisition never blocks inside flock: the lock is attempted at least once
    and then polled until it is acquired or the timeout is reached.

    Usage:
        with FileLock(lock_path, exclusive=True, timeout=15, label="abc"):
            # do work under lock
    """

    # Seconds to sleep between two non-blocking acquisition attempts.
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(
        self,
        lock_path: Path,
        exclusive: bool,
        timeout: float = 15,
        label: str = "",
    ):
        """
        Args:
            lock_path: Path to the lock file on disk (created if missing on enter).
            exclusive: True for exclusive (write) lock, False for shared (read) lock.
            timeout: Max seconds to wait for lock acquisition before raising RuntimeError.
                A value of 0 means "try exactly once".
            label: Optional human-readable label for error messages.
        """
        self.lock_path: Path = lock_path
        self.exclusive: bool = exclusive
        self.timeout: float = timeout
        self.label: str = label
        # None means "no descriptor held"; __exit__ relies on this sentinel.
        self._fd: int | None = None

    @property
    def _lock_label(self) -> str:
        """Lock kind, followed by the label when one was given (used in error messages)."""
        kind = "exclusive" if self.exclusive else "shared"
        return f"{kind} {self.label}" if self.label else kind

    def __enter__(self) -> "FileLock":
        """Open the lock file and acquire the lock.

        Returns:
            This FileLock instance, holding the lock.

        Raises:
            RuntimeError: if the lock could not be acquired within `timeout` seconds.
                The last OSError from flock is attached as the cause.
            OSError: if the lock file cannot be opened.
        """
        if self.exclusive:
            open_flags = os.O_WRONLY | os.O_CREAT
            lock_type = fcntl.LOCK_EX
        else:
            open_flags = os.O_RDONLY | os.O_CREAT
            lock_type = fcntl.LOCK_SH
        fd = os.open(str(self.lock_path), open_flags)
        deadline = time.monotonic() + self.timeout
        try:
            while True:
                try:
                    fcntl.flock(fd, lock_type | fcntl.LOCK_NB)
                    break
                except OSError as err:
                    # Checking the deadline only after a failed attempt guarantees
                    # at least one attempt, even when timeout is 0.
                    if time.monotonic() >= deadline:
                        raise RuntimeError(
                            f"Timed out after {self.timeout}s waiting for "
                            f"{self._lock_label} lock: {self.lock_path}"
                        ) from err
                    time.sleep(self._POLL_INTERVAL_SECONDS)
        except BaseException:
            # Never leak the descriptor, whether on timeout or on an interruption
            # such as KeyboardInterrupt while polling.
            os.close(fd)
            raise
        self._fd = fd
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Release the lock and close the lock file; a no-op when no lock is held."""
        fd, self._fd = self._fd, None
        if fd is None:
            return
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            # Close even if the explicit unlock fails.
            os.close(fd)
class JITCache:
    """
    Process-local store mapping compile keys to compiled functions.

    Supports `cache[key] = fn`, `cache[key]` and `key in cache`.
    """

    def __init__(self):
        # Backing dictionary; nothing is persisted outside this process.
        self.cache: dict[CompileKeyType, CallableFunction] = dict()

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Store `fn` under `key`, replacing any previous entry."""
        self.cache.update({key: fn})

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        """Return the function stored under `key`; raises KeyError when absent."""
        store = self.cache
        return store[key]

    def __contains__(self, key: CompileKeyType) -> bool:
        """Report whether `key` has an entry in memory."""
        store = self.cache
        return key in store

    def clear(self) -> None:
        """
        Drop every compiled function held in memory.
        """
        # Cleared in place so existing references to the dict observe the purge.
        self.cache.clear()
class JITPersistentCache(JITCache):
    """
    In-memory cache for compiled functions, which is also backed by persistent storage.
    Use cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True

    On-disk layout inside `cache_path`, per compile key:
    - `<sha256 of pickled key>.o`: the exported object file.
    - `<sha256 of pickled key>.lock`: the lock file guarding that object file.
    """

    # Name under which a function is exported and later looked up on the loaded module.
    EXPORT_FUNCTION_PREFIX = "func"
    # Max seconds to wait for a per-key file lock before giving up.
    LOCK_TIMEOUT_SECONDS = 15

    def __init__(self, cache_path: Path):
        """
        Args:
            cache_path: Directory holding this cache's artifacts; created if missing.
        """
        super().__init__()
        cache_path.mkdir(parents=True, exist_ok=True)
        self.cache_path: Path = cache_path

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Store `fn` in memory and export it to disk unless it is already there."""
        JITCache.__setitem__(self, key, fn)
        self._try_export_to_storage(key, fn)

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        """Return the function for `key`, loading it from disk first if needed.

        Raises:
            KeyError: if the key is neither in memory nor on disk.
        """
        # Use __contains__ to try populating in-memory cache with persistent storage
        self.__contains__(key)
        return JITCache.__getitem__(self, key)

    def __contains__(self, key: CompileKeyType) -> bool:
        """Report whether `key` is available, in memory or on disk."""
        # Checks in-memory cache first, then tries loading from storage.
        # When returning True, guarantees the in-memory cache is populated.
        if JITCache.__contains__(self, key):
            return True
        return self._try_load_from_storage(key)

    def _try_load_from_storage(self, key: CompileKeyType) -> bool:
        """
        Try to load a function from persistent storage into in-memory cache.
        Returns True if loaded successfully, False if not found on disk.
        Holds a shared lock during loading to prevent concurrent writes.

        Raises:
            RuntimeError: if the shared lock is not acquired within LOCK_TIMEOUT_SECONDS.
        """
        sha256_hex = self._key_to_hash(key)
        obj_path = self.cache_path / f"{sha256_hex}.o"
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=False,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            if obj_path.exists():
                logger.debug("Loading compiled function from disk: %s", obj_path)
                m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
                fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
                # Bypass our own __setitem__ so the loaded function is not re-exported.
                JITCache.__setitem__(self, key, fn)
                return True
            else:
                logger.debug("Cache miss on disk for key hash %s", sha256_hex)
                return False

    def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        """Export a compiled function to persistent storage under exclusive lock.

        Does nothing when the object file already exists.

        Raises:
            RuntimeError: if the exclusive lock is not acquired within LOCK_TIMEOUT_SECONDS.
        """
        sha256_hex = self._key_to_hash(key)
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=True,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            obj_path = self.cache_path / f"{sha256_hex}.o"
            if obj_path.exists():
                # Another process already exported.
                logger.debug("Skipping export, already on disk: %s", obj_path)
                return
            logger.debug("Exporting compiled function to disk: %s", obj_path)
            # NOTE(review): the export writes straight to the final path. If the
            # process dies mid-export, a partial file may remain and be treated as
            # a valid entry afterwards - consider exporting to a temporary path and
            # renaming once export_to_c's path handling is confirmed.
            fn.export_to_c(
                object_file_path=str(obj_path),
                function_name=self.EXPORT_FUNCTION_PREFIX,
            )
            logger.debug("Successfully exported compiled function to disk: %s", obj_path)

    def _key_to_hash(self, key: CompileKeyType) -> str:
        """Return the hex SHA-256 of the pickled key, used as the on-disk file stem."""
        # NOTE(review): assumes pickle.dumps(key) is byte-stable across processes
        # for the key types in use - TODO confirm.
        return hashlib.sha256(pickle.dumps(key)).hexdigest()

    def _lock_path(self, sha256_hex: str) -> Path:
        """Return the path of the lock file guarding the entry with this hash."""
        return self.cache_path / f"{sha256_hex}.lock"

    def clear(self) -> None:
        """
        Not only clear the in-memory cache. Also purge persistent compilation cache.

        Removes the files (object and lock files) directly inside `cache_path`.
        Subdirectories are left untouched.
        """
        logger.debug("Clearing persistent cache at %s", self.cache_path)
        super().clear()
        for child in self.cache_path.iterdir():
            if child.is_dir() and not child.is_symlink():
                # Named caches are created as subdirectories of the unnamed cache
                # directory; unlink() cannot remove directories, and they belong
                # to other cache instances anyway.
                continue
            # Tolerate another process removing the same file concurrently.
            child.unlink(missing_ok=True)
def get_jit_cache(name: str | None = None) -> JITCache:
    """
    Build the JIT cache appropriate for the current configuration.

    With persistent caching disabled, a plain in-memory `JITCache` is returned.
    With it enabled, a `JITPersistentCache` is returned whose directory is
    `<cache root>/<source fingerprint>`, extended by `/<name>` when `name` is
    given. Placing artifacts under the fingerprint means a code or dependency
    change selects a fresh directory, so stale entries are never picked up.

    Args:
        name: Optional identifier used as a subdirectory to separate caches.

    Returns:
        A new cache instance.
    """
    if not CUTE_DSL_CACHE_ENABLED:
        logger.debug("Persistent cache disabled, using in-memory JIT cache")
        return JITCache()

    cache_dir = get_cache_path() / _compute_source_fingerprint()
    if name:
        cache_dir = cache_dir / name
    logger.debug("Creating persistent JIT cache at %s", cache_dir)
    return JITPersistentCache(cache_dir)
|