Build uploaded using `kernels`.
Browse files- build/torch-cuda/_ops.py +2 -2
- build/torch-cuda/cache_utils.py +30 -49
build/torch-cuda/_ops.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
-
ops = torch.ops.
|
| 3 |
|
| 4 |
def add_op_namespace_prefix(op_name: str):
|
| 5 |
"""
|
| 6 |
Prefix op by namespace.
|
| 7 |
"""
|
| 8 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
ops = torch.ops._flash_attn4_c07a63b_dirty
|
| 3 |
|
| 4 |
def add_op_namespace_prefix(op_name: str):
|
| 5 |
"""
|
| 6 |
Prefix op by namespace.
|
| 7 |
"""
|
| 8 |
+
return f"_flash_attn4_c07a63b_dirty::{op_name}"
|
build/torch-cuda/cache_utils.py
CHANGED
|
@@ -7,23 +7,34 @@ import pickle
|
|
| 7 |
import sys
|
| 8 |
import tempfile
|
| 9 |
import time
|
| 10 |
-
from distutils.ccompiler import CCompiler, new_compiler
|
| 11 |
from functools import lru_cache
|
| 12 |
from getpass import getuser
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Hashable, TypeAlias
|
| 15 |
|
|
|
|
|
|
|
| 16 |
import cutlass
|
| 17 |
import cutlass.cute as cute
|
| 18 |
import tvm_ffi
|
| 19 |
from cutlass.cutlass_dsl import JitCompiledFunction
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
CompileKeyType: TypeAlias = tuple[Hashable, ...]
|
| 22 |
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
|
| 23 |
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
|
|
@@ -64,6 +75,8 @@ def _compute_source_fingerprint() -> str:
|
|
| 64 |
h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
|
| 65 |
|
| 66 |
for src in sorted(cute_root.rglob("*.py")):
|
|
|
|
|
|
|
| 67 |
h.update(src.relative_to(cute_root).as_posix().encode())
|
| 68 |
content = src.read_bytes()
|
| 69 |
h.update(len(content).to_bytes(8, "little"))
|
|
@@ -109,9 +122,7 @@ class FileLock:
|
|
| 109 |
return f"{kind} {self.label}" if self.label else kind
|
| 110 |
|
| 111 |
def __enter__(self) -> "FileLock":
|
| 112 |
-
open_flags =
|
| 113 |
-
os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
|
| 114 |
-
)
|
| 115 |
lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
|
| 116 |
|
| 117 |
self._fd = os.open(str(self.lock_path), open_flags)
|
|
@@ -175,8 +186,6 @@ class JITPersistentCache(JITCache):
|
|
| 175 |
EXPORT_FUNCTION_PREFIX = "func"
|
| 176 |
LOCK_TIMEOUT_SECONDS = 15
|
| 177 |
|
| 178 |
-
_compiler: CCompiler | None = None
|
| 179 |
-
|
| 180 |
def __init__(self, cache_path: Path):
|
| 181 |
super().__init__()
|
| 182 |
cache_path.mkdir(parents=True, exist_ok=True)
|
|
@@ -205,32 +214,24 @@ class JITPersistentCache(JITCache):
|
|
| 205 |
Holds a shared lock during loading to prevent concurrent writes.
|
| 206 |
"""
|
| 207 |
sha256_hex = self._key_to_hash(key)
|
| 208 |
-
|
| 209 |
with FileLock(
|
| 210 |
self._lock_path(sha256_hex),
|
| 211 |
exclusive=False,
|
| 212 |
timeout=self.LOCK_TIMEOUT_SECONDS,
|
| 213 |
label=sha256_hex,
|
| 214 |
):
|
| 215 |
-
if
|
| 216 |
-
logger.debug(
|
| 217 |
-
|
| 218 |
-
)
|
| 219 |
-
m = cute.runtime.load_module(
|
| 220 |
-
str(so_path), enable_tvm_ffi=True
|
| 221 |
-
)
|
| 222 |
fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
|
| 223 |
JITCache.__setitem__(self, key, fn)
|
| 224 |
return True
|
| 225 |
else:
|
| 226 |
-
logger.debug(
|
| 227 |
-
"Cache miss on disk for key hash %s", sha256_hex
|
| 228 |
-
)
|
| 229 |
return False
|
| 230 |
|
| 231 |
-
def _try_export_to_storage(
|
| 232 |
-
self, key: CompileKeyType, fn: JitCompiledFunction
|
| 233 |
-
) -> None:
|
| 234 |
"""Export a compiled function to persistent storage under exclusive lock."""
|
| 235 |
sha256_hex = self._key_to_hash(key)
|
| 236 |
with FileLock(
|
|
@@ -239,33 +240,17 @@ class JITPersistentCache(JITCache):
|
|
| 239 |
timeout=self.LOCK_TIMEOUT_SECONDS,
|
| 240 |
label=sha256_hex,
|
| 241 |
):
|
| 242 |
-
|
| 243 |
-
if
|
| 244 |
# Another process already exported.
|
| 245 |
-
logger.debug(
|
| 246 |
-
"Skipping export, already on disk: %s", so_path
|
| 247 |
-
)
|
| 248 |
return
|
| 249 |
-
|
| 250 |
-
logger.debug(
|
| 251 |
-
"Exporting compiled function to disk: %s", so_path
|
| 252 |
-
)
|
| 253 |
fn.export_to_c(
|
| 254 |
object_file_path=str(obj_path),
|
| 255 |
function_name=self.EXPORT_FUNCTION_PREFIX,
|
| 256 |
)
|
| 257 |
-
|
| 258 |
-
# "relocatable" .o files. But tvm_ffi expects "shared library" .so
|
| 259 |
-
# files. Link ourselves to workaround.
|
| 260 |
-
if JITPersistentCache._compiler is None:
|
| 261 |
-
JITPersistentCache._compiler = new_compiler()
|
| 262 |
-
JITPersistentCache._compiler.link_shared_object(
|
| 263 |
-
[str(obj_path)], str(so_path)
|
| 264 |
-
)
|
| 265 |
-
obj_path.unlink()
|
| 266 |
-
logger.debug(
|
| 267 |
-
"Successfully exported compiled function to disk: %s", so_path
|
| 268 |
-
)
|
| 269 |
|
| 270 |
def _key_to_hash(self, key: CompileKeyType) -> str:
|
| 271 |
return hashlib.sha256(pickle.dumps(key)).hexdigest()
|
|
@@ -277,9 +262,7 @@ class JITPersistentCache(JITCache):
|
|
| 277 |
"""
|
| 278 |
Not only clear the in-memory cache. Also purge persistent compilation cache.
|
| 279 |
"""
|
| 280 |
-
logger.debug(
|
| 281 |
-
"Clearing persistent cache at %s", self.cache_path
|
| 282 |
-
)
|
| 283 |
super().clear()
|
| 284 |
for child in self.cache_path.iterdir():
|
| 285 |
child.unlink()
|
|
@@ -298,9 +281,7 @@ def get_jit_cache(name: str | None = None) -> JITCache:
|
|
| 298 |
path = get_cache_path() / _compute_source_fingerprint()
|
| 299 |
if name:
|
| 300 |
path = path / name
|
| 301 |
-
logger.debug(
|
| 302 |
-
"Creating persistent JIT cache at %s", path
|
| 303 |
-
)
|
| 304 |
return JITPersistentCache(path)
|
| 305 |
else:
|
| 306 |
logger.debug("Persistent cache disabled, using in-memory JIT cache")
|
|
|
|
| 7 |
import sys
|
| 8 |
import tempfile
|
| 9 |
import time
|
|
|
|
| 10 |
from functools import lru_cache
|
| 11 |
from getpass import getuser
|
| 12 |
from pathlib import Path
|
| 13 |
from typing import Hashable, TypeAlias
|
| 14 |
|
| 15 |
+
import ctypes
|
| 16 |
+
|
| 17 |
import cutlass
|
| 18 |
import cutlass.cute as cute
|
| 19 |
import tvm_ffi
|
| 20 |
from cutlass.cutlass_dsl import JitCompiledFunction
|
| 21 |
|
| 22 |
+
# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
|
| 23 |
+
# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
|
| 24 |
+
# Upstream cute.runtime.load_module loads these without RTLD_GLOBAL, which causes
|
| 25 |
+
# "undefined symbol" errors when loading cached kernels from disk.
|
| 26 |
+
for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
|
| 27 |
+
if Path(_lib_path).exists():
|
| 28 |
+
ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)
|
| 29 |
+
|
| 30 |
CompileKeyType: TypeAlias = tuple[Hashable, ...]
|
| 31 |
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
+
_handler = logging.StreamHandler()
|
| 35 |
+
_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
| 36 |
+
logger.addHandler(_handler)
|
| 37 |
+
logger.setLevel(logging.DEBUG)
|
| 38 |
|
| 39 |
|
| 40 |
# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
|
|
|
|
| 75 |
h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
|
| 76 |
|
| 77 |
for src in sorted(cute_root.rglob("*.py")):
|
| 78 |
+
if not src.is_file():
|
| 79 |
+
continue
|
| 80 |
h.update(src.relative_to(cute_root).as_posix().encode())
|
| 81 |
content = src.read_bytes()
|
| 82 |
h.update(len(content).to_bytes(8, "little"))
|
|
|
|
| 122 |
return f"{kind} {self.label}" if self.label else kind
|
| 123 |
|
| 124 |
def __enter__(self) -> "FileLock":
|
| 125 |
+
open_flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
|
|
|
|
|
|
|
| 126 |
lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
|
| 127 |
|
| 128 |
self._fd = os.open(str(self.lock_path), open_flags)
|
|
|
|
| 186 |
EXPORT_FUNCTION_PREFIX = "func"
|
| 187 |
LOCK_TIMEOUT_SECONDS = 15
|
| 188 |
|
|
|
|
|
|
|
| 189 |
def __init__(self, cache_path: Path):
|
| 190 |
super().__init__()
|
| 191 |
cache_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 214 |
Holds a shared lock during loading to prevent concurrent writes.
|
| 215 |
"""
|
| 216 |
sha256_hex = self._key_to_hash(key)
|
| 217 |
+
obj_path = self.cache_path / f"{sha256_hex}.o"
|
| 218 |
with FileLock(
|
| 219 |
self._lock_path(sha256_hex),
|
| 220 |
exclusive=False,
|
| 221 |
timeout=self.LOCK_TIMEOUT_SECONDS,
|
| 222 |
label=sha256_hex,
|
| 223 |
):
|
| 224 |
+
if obj_path.exists():
|
| 225 |
+
logger.debug("Loading compiled function from disk: %s", obj_path)
|
| 226 |
+
m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
|
| 228 |
JITCache.__setitem__(self, key, fn)
|
| 229 |
return True
|
| 230 |
else:
|
| 231 |
+
logger.debug("Cache miss on disk for key hash %s", sha256_hex)
|
|
|
|
|
|
|
| 232 |
return False
|
| 233 |
|
| 234 |
+
def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
|
|
|
|
|
|
|
| 235 |
"""Export a compiled function to persistent storage under exclusive lock."""
|
| 236 |
sha256_hex = self._key_to_hash(key)
|
| 237 |
with FileLock(
|
|
|
|
| 240 |
timeout=self.LOCK_TIMEOUT_SECONDS,
|
| 241 |
label=sha256_hex,
|
| 242 |
):
|
| 243 |
+
obj_path = self.cache_path / f"{sha256_hex}.o"
|
| 244 |
+
if obj_path.exists():
|
| 245 |
# Another process already exported.
|
| 246 |
+
logger.debug("Skipping export, already on disk: %s", obj_path)
|
|
|
|
|
|
|
| 247 |
return
|
| 248 |
+
logger.debug("Exporting compiled function to disk: %s", obj_path)
|
|
|
|
|
|
|
|
|
|
| 249 |
fn.export_to_c(
|
| 250 |
object_file_path=str(obj_path),
|
| 251 |
function_name=self.EXPORT_FUNCTION_PREFIX,
|
| 252 |
)
|
| 253 |
+
logger.debug("Successfully exported compiled function to disk: %s", obj_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
def _key_to_hash(self, key: CompileKeyType) -> str:
|
| 256 |
return hashlib.sha256(pickle.dumps(key)).hexdigest()
|
|
|
|
| 262 |
"""
|
| 263 |
Not only clear the in-memory cache. Also purge persistent compilation cache.
|
| 264 |
"""
|
| 265 |
+
logger.debug("Clearing persistent cache at %s", self.cache_path)
|
|
|
|
|
|
|
| 266 |
super().clear()
|
| 267 |
for child in self.cache_path.iterdir():
|
| 268 |
child.unlink()
|
|
|
|
| 281 |
path = get_cache_path() / _compute_source_fingerprint()
|
| 282 |
if name:
|
| 283 |
path = path / name
|
| 284 |
+
logger.debug("Creating persistent JIT cache at %s", path)
|
|
|
|
|
|
|
| 285 |
return JITPersistentCache(path)
|
| 286 |
else:
|
| 287 |
logger.debug("Persistent cache disabled, using in-memory JIT cache")
|