Upload modeling_e1.py with huggingface_hub
Browse files- modeling_e1.py +13 -1
modeling_e1.py
CHANGED
|
@@ -436,7 +436,17 @@ def _try_get_kernels_flash():
|
|
| 436 |
return flash_kernel, flash_kernel_variant
|
| 437 |
|
| 438 |
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
| 441 |
|
| 442 |
def _kernels_flash_forward(
|
|
@@ -546,6 +556,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
|
|
| 546 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 547 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 548 |
)
|
|
|
|
|
|
|
| 549 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 550 |
if FLASH_KERNEL is not None:
|
| 551 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
|
|
| 436 |
return flash_kernel, flash_kernel_variant
|
| 437 |
|
| 438 |
|
| 439 |
+
_FLASH_KERNELS_LOADED = False
|
| 440 |
+
FLASH_KERNEL = None
|
| 441 |
+
FLASH_KERNEL_VARIANT = None
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _ensure_flash_kernels_loaded():
|
| 445 |
+
global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
|
| 446 |
+
if _FLASH_KERNELS_LOADED:
|
| 447 |
+
return
|
| 448 |
+
_FLASH_KERNELS_LOADED = True
|
| 449 |
+
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
| 450 |
|
| 451 |
|
| 452 |
def _kernels_flash_forward(
|
|
|
|
| 556 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 557 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 558 |
)
|
| 559 |
+
if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
|
| 560 |
+
_ensure_flash_kernels_loaded()
|
| 561 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 562 |
if FLASH_KERNEL is not None:
|
| 563 |
resolved = AttentionBackend.KERNELS_FLASH
|