danieldk HF Staff committed on
Commit
ef1189b
·
verified ·
1 Parent(s): 4298e26

Build uploaded using `kernels`.

Browse files
build/torch-cuda/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._flash_attn4_c07a63b
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_flash_attn4_c07a63b::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._flash_attn4_c07a63b_dirty
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_flash_attn4_c07a63b_dirty::{op_name}"
build/torch-cuda/cache_utils.py CHANGED
@@ -7,23 +7,34 @@ import pickle
7
  import sys
8
  import tempfile
9
  import time
10
- from distutils.ccompiler import CCompiler, new_compiler
11
  from functools import lru_cache
12
  from getpass import getuser
13
  from pathlib import Path
14
  from typing import Hashable, TypeAlias
15
 
 
 
16
  import cutlass
17
  import cutlass.cute as cute
18
  import tvm_ffi
19
  from cutlass.cutlass_dsl import JitCompiledFunction
20
 
 
 
 
 
 
 
 
 
21
  CompileKeyType: TypeAlias = tuple[Hashable, ...]
22
  CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
23
 
24
  logger = logging.getLogger(__name__)
25
- logger.addHandler(logging.StreamHandler())
26
- logger.setLevel(logging.WARNING)
 
 
27
 
28
 
29
  # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
@@ -64,6 +75,8 @@ def _compute_source_fingerprint() -> str:
64
  h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
65
 
66
  for src in sorted(cute_root.rglob("*.py")):
 
 
67
  h.update(src.relative_to(cute_root).as_posix().encode())
68
  content = src.read_bytes()
69
  h.update(len(content).to_bytes(8, "little"))
@@ -109,9 +122,7 @@ class FileLock:
109
  return f"{kind} {self.label}" if self.label else kind
110
 
111
  def __enter__(self) -> "FileLock":
112
- open_flags = (
113
- os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
114
- )
115
  lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
116
 
117
  self._fd = os.open(str(self.lock_path), open_flags)
@@ -175,8 +186,6 @@ class JITPersistentCache(JITCache):
175
  EXPORT_FUNCTION_PREFIX = "func"
176
  LOCK_TIMEOUT_SECONDS = 15
177
 
178
- _compiler: CCompiler | None = None
179
-
180
  def __init__(self, cache_path: Path):
181
  super().__init__()
182
  cache_path.mkdir(parents=True, exist_ok=True)
@@ -205,32 +214,24 @@ class JITPersistentCache(JITCache):
205
  Holds a shared lock during loading to prevent concurrent writes.
206
  """
207
  sha256_hex = self._key_to_hash(key)
208
- so_path = self.cache_path / f"{sha256_hex}.so"
209
  with FileLock(
210
  self._lock_path(sha256_hex),
211
  exclusive=False,
212
  timeout=self.LOCK_TIMEOUT_SECONDS,
213
  label=sha256_hex,
214
  ):
215
- if so_path.exists():
216
- logger.debug(
217
- "Loading compiled function from disk: %s", so_path
218
- )
219
- m = cute.runtime.load_module(
220
- str(so_path), enable_tvm_ffi=True
221
- )
222
  fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
223
  JITCache.__setitem__(self, key, fn)
224
  return True
225
  else:
226
- logger.debug(
227
- "Cache miss on disk for key hash %s", sha256_hex
228
- )
229
  return False
230
 
231
- def _try_export_to_storage(
232
- self, key: CompileKeyType, fn: JitCompiledFunction
233
- ) -> None:
234
  """Export a compiled function to persistent storage under exclusive lock."""
235
  sha256_hex = self._key_to_hash(key)
236
  with FileLock(
@@ -239,33 +240,17 @@ class JITPersistentCache(JITCache):
239
  timeout=self.LOCK_TIMEOUT_SECONDS,
240
  label=sha256_hex,
241
  ):
242
- so_path = self.cache_path / f"{sha256_hex}.so"
243
- if so_path.exists():
244
  # Another process already exported.
245
- logger.debug(
246
- "Skipping export, already on disk: %s", so_path
247
- )
248
  return
249
- obj_path = self.cache_path / f"{sha256_hex}.o"
250
- logger.debug(
251
- "Exporting compiled function to disk: %s", so_path
252
- )
253
  fn.export_to_c(
254
  object_file_path=str(obj_path),
255
  function_name=self.EXPORT_FUNCTION_PREFIX,
256
  )
257
- # TODO: as of cutedsl 4.4.0, `export_to_c` only supports exporting
258
- # "relocatable" .o files. But tvm_ffi expects "shared library" .so
259
- # files. Link ourselves to workaround.
260
- if JITPersistentCache._compiler is None:
261
- JITPersistentCache._compiler = new_compiler()
262
- JITPersistentCache._compiler.link_shared_object(
263
- [str(obj_path)], str(so_path)
264
- )
265
- obj_path.unlink()
266
- logger.debug(
267
- "Successfully exported compiled function to disk: %s", so_path
268
- )
269
 
270
  def _key_to_hash(self, key: CompileKeyType) -> str:
271
  return hashlib.sha256(pickle.dumps(key)).hexdigest()
@@ -277,9 +262,7 @@ class JITPersistentCache(JITCache):
277
  """
278
  Not only clear the in-memory cache. Also purge persistent compilation cache.
279
  """
280
- logger.debug(
281
- "Clearing persistent cache at %s", self.cache_path
282
- )
283
  super().clear()
284
  for child in self.cache_path.iterdir():
285
  child.unlink()
@@ -298,9 +281,7 @@ def get_jit_cache(name: str | None = None) -> JITCache:
298
  path = get_cache_path() / _compute_source_fingerprint()
299
  if name:
300
  path = path / name
301
- logger.debug(
302
- "Creating persistent JIT cache at %s", path
303
- )
304
  return JITPersistentCache(path)
305
  else:
306
  logger.debug("Persistent cache disabled, using in-memory JIT cache")
 
7
  import sys
8
  import tempfile
9
  import time
 
10
  from functools import lru_cache
11
  from getpass import getuser
12
  from pathlib import Path
13
  from typing import Hashable, TypeAlias
14
 
15
+ import ctypes
16
+
17
  import cutlass
18
  import cutlass.cute as cute
19
  import tvm_ffi
20
  from cutlass.cutlass_dsl import JitCompiledFunction
21
 
22
+ # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
23
+ # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
24
+ # Upstream cute.runtime.load_module loads these without RTLD_GLOBAL, which causes
25
+ # "undefined symbol" errors when loading cached kernels from disk.
26
+ for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
27
+ if Path(_lib_path).exists():
28
+ ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)
29
+
30
  CompileKeyType: TypeAlias = tuple[Hashable, ...]
31
  CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
32
 
33
  logger = logging.getLogger(__name__)
34
+ _handler = logging.StreamHandler()
35
+ _handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
36
+ logger.addHandler(_handler)
37
+ logger.setLevel(logging.DEBUG)
38
 
39
 
40
  # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
 
75
  h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
76
 
77
  for src in sorted(cute_root.rglob("*.py")):
78
+ if not src.is_file():
79
+ continue
80
  h.update(src.relative_to(cute_root).as_posix().encode())
81
  content = src.read_bytes()
82
  h.update(len(content).to_bytes(8, "little"))
 
122
  return f"{kind} {self.label}" if self.label else kind
123
 
124
  def __enter__(self) -> "FileLock":
125
+ open_flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
 
 
126
  lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
127
 
128
  self._fd = os.open(str(self.lock_path), open_flags)
 
186
  EXPORT_FUNCTION_PREFIX = "func"
187
  LOCK_TIMEOUT_SECONDS = 15
188
 
 
 
189
  def __init__(self, cache_path: Path):
190
  super().__init__()
191
  cache_path.mkdir(parents=True, exist_ok=True)
 
214
  Holds a shared lock during loading to prevent concurrent writes.
215
  """
216
  sha256_hex = self._key_to_hash(key)
217
+ obj_path = self.cache_path / f"{sha256_hex}.o"
218
  with FileLock(
219
  self._lock_path(sha256_hex),
220
  exclusive=False,
221
  timeout=self.LOCK_TIMEOUT_SECONDS,
222
  label=sha256_hex,
223
  ):
224
+ if obj_path.exists():
225
+ logger.debug("Loading compiled function from disk: %s", obj_path)
226
+ m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
 
 
 
 
227
  fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
228
  JITCache.__setitem__(self, key, fn)
229
  return True
230
  else:
231
+ logger.debug("Cache miss on disk for key hash %s", sha256_hex)
 
 
232
  return False
233
 
234
+ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
 
 
235
  """Export a compiled function to persistent storage under exclusive lock."""
236
  sha256_hex = self._key_to_hash(key)
237
  with FileLock(
 
240
  timeout=self.LOCK_TIMEOUT_SECONDS,
241
  label=sha256_hex,
242
  ):
243
+ obj_path = self.cache_path / f"{sha256_hex}.o"
244
+ if obj_path.exists():
245
  # Another process already exported.
246
+ logger.debug("Skipping export, already on disk: %s", obj_path)
 
 
247
  return
248
+ logger.debug("Exporting compiled function to disk: %s", obj_path)
 
 
 
249
  fn.export_to_c(
250
  object_file_path=str(obj_path),
251
  function_name=self.EXPORT_FUNCTION_PREFIX,
252
  )
253
+ logger.debug("Successfully exported compiled function to disk: %s", obj_path)
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  def _key_to_hash(self, key: CompileKeyType) -> str:
256
  return hashlib.sha256(pickle.dumps(key)).hexdigest()
 
262
  """
263
  Not only clear the in-memory cache. Also purge persistent compilation cache.
264
  """
265
+ logger.debug("Clearing persistent cache at %s", self.cache_path)
 
 
266
  super().clear()
267
  for child in self.cache_path.iterdir():
268
  child.unlink()
 
281
  path = get_cache_path() / _compute_source_fingerprint()
282
  if name:
283
  path = path / name
284
+ logger.debug("Creating persistent JIT cache at %s", path)
 
 
285
  return JITPersistentCache(path)
286
  else:
287
  logger.debug("Persistent cache disabled, using in-memory JIT cache")