lhallee committed on
Commit
7649515
·
verified ·
1 Parent(s): a94df5d

Upload modeling_e1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_e1.py +13 -1
modeling_e1.py CHANGED
@@ -436,7 +436,17 @@ def _try_get_kernels_flash():
436
  return flash_kernel, flash_kernel_variant
437
 
438
 
439
- FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
 
 
 
 
 
 
 
 
 
 
440
 
441
 
442
  def _kernels_flash_forward(
@@ -546,6 +556,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
546
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
547
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
548
  )
 
 
549
  if requested_backend == AttentionBackend.AUTO.value:
550
  if FLASH_KERNEL is not None:
551
  resolved = AttentionBackend.KERNELS_FLASH
 
436
  return flash_kernel, flash_kernel_variant
437
 
438
 
439
# Lazy flash-kernel state: left unset at import time and populated exactly
# once by _ensure_flash_kernels_loaded(), replacing the previous eager
# module-import call to _try_get_kernels_flash().
_FLASH_KERNELS_LOADED = False  # guard so resolution is attempted only once
FLASH_KERNEL = None            # resolved kernel handle, or None if unavailable
FLASH_KERNEL_VARIANT = None    # variant tag returned alongside the kernel
443
+
444
def _ensure_flash_kernels_loaded():
    """Resolve the flash-attention kernels lazily, at most once per process.

    Populates the module-level ``FLASH_KERNEL`` / ``FLASH_KERNEL_VARIANT``
    globals via ``_try_get_kernels_flash()``. Subsequent calls are no-ops.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        # Mark as loaded *before* resolving so a failed attempt is not
        # retried on every call.
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
450
 
451
 
452
  def _kernels_flash_forward(
 
556
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
557
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
558
  )
559
+ if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
560
+ _ensure_flash_kernels_loaded()
561
  if requested_backend == AttentionBackend.AUTO.value:
562
  if FLASH_KERNEL is not None:
563
  resolved = AttentionBackend.KERNELS_FLASH