Kuangwei Chen committed on
Commit ·
8ee35eb
1
Parent(s): 737e5b1
optimize flash attention import
Browse files- modeling_moss_audio_tokenizer.py +33 -11
modeling_moss_audio_tokenizer.py
CHANGED
|
@@ -16,11 +16,13 @@
|
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
import copy
|
|
|
|
| 19 |
import math
|
| 20 |
import sys
|
| 21 |
import types
|
| 22 |
from contextlib import ExitStack, contextmanager
|
| 23 |
from dataclasses import dataclass
|
|
|
|
| 24 |
from pathlib import Path
|
| 25 |
from typing import cast
|
| 26 |
|
|
@@ -91,13 +93,31 @@ except ImportError:
|
|
| 91 |
|
| 92 |
logger = logging.get_logger(__name__)
|
| 93 |
|
| 94 |
-
try:
|
| 95 |
-
from flash_attn import flash_attn_varlen_func
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
|
|
@@ -328,7 +348,7 @@ class MossAudioTokenizerDecodeSession:
|
|
| 328 |
decoder_attention_modules.append(module)
|
| 329 |
|
| 330 |
flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
|
| 331 |
-
if use_cuda_graph and
|
| 332 |
for module in decoder_attention_modules:
|
| 333 |
module._use_flash_kvcache = True
|
| 334 |
flash_kvcache_attention_modules.append(module)
|
|
@@ -1237,7 +1257,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1237 |
)
|
| 1238 |
|
| 1239 |
def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool:
|
| 1240 |
-
return
|
| 1241 |
|
| 1242 |
def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype:
|
| 1243 |
if x.device.type != "cuda":
|
|
@@ -1265,7 +1285,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1265 |
"(HAS_FLASH_ATTN=%s).",
|
| 1266 |
x.device,
|
| 1267 |
backend_dtype,
|
| 1268 |
-
|
| 1269 |
)
|
| 1270 |
return "sdpa"
|
| 1271 |
|
|
@@ -1440,6 +1460,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1440 |
max_seqlen_q: int,
|
| 1441 |
max_seqlen_k: int,
|
| 1442 |
) -> torch.Tensor:
|
|
|
|
| 1443 |
if flash_attn_varlen_func is None:
|
| 1444 |
raise RuntimeError("flash-attn is not installed.")
|
| 1445 |
window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
|
|
@@ -1525,10 +1546,11 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
|
|
| 1525 |
return out
|
| 1526 |
|
| 1527 |
def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
|
| 1528 |
-
|
| 1529 |
-
|
| 1530 |
if self.context is None:
|
| 1531 |
raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
|
|
|
|
|
|
|
| 1532 |
|
| 1533 |
batch_size, chunk_length, _ = x.shape
|
| 1534 |
q, k_cur, v_cur = self._project_qkv(x)
|
|
|
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
import copy
|
| 19 |
+
import importlib
|
| 20 |
import math
|
| 21 |
import sys
|
| 22 |
import types
|
| 23 |
from contextlib import ExitStack, contextmanager
|
| 24 |
from dataclasses import dataclass
|
| 25 |
+
from functools import lru_cache
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import cast
|
| 28 |
|
|
|
|
| 93 |
|
| 94 |
logger = logging.get_logger(__name__)
|
| 95 |
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
@lru_cache(maxsize=1)
|
| 98 |
+
def _get_flash_attn_module():
|
| 99 |
+
try:
|
| 100 |
+
return importlib.import_module("flash_attn")
|
| 101 |
+
except Exception:
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _has_flash_attn() -> bool:
    """Tell whether the optional flash-attn package could be imported."""
    flash_module = _get_flash_attn_module()
    return flash_module is not None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _get_flash_attn_varlen_func():
    """Return ``flash_attn.flash_attn_varlen_func`` if available, else ``None``.

    ``None`` is returned both when flash-attn is not installed and when the
    installed version does not expose the attribute.
    """
    module = _get_flash_attn_module()
    if module is None:
        return None
    return getattr(module, "flash_attn_varlen_func", None)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_flash_attn_with_kvcache():
    """Return ``flash_attn.flash_attn_with_kvcache`` if available, else ``None``.

    Mirrors ``_get_flash_attn_varlen_func``: ``None`` when flash-attn is
    missing or the installed build lacks this entry point.
    """
    module = _get_flash_attn_module()
    if module is None:
        return None
    return getattr(module, "flash_attn_with_kvcache", None)
|
| 121 |
|
| 122 |
|
| 123 |
# Attention backends accepted by this model; presumably validated against a
# requested `attn_implementation` elsewhere in the file — confirm at call sites.
SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
|
|
|
|
| 348 |
decoder_attention_modules.append(module)
|
| 349 |
|
| 350 |
flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
|
| 351 |
+
if use_cuda_graph and _has_flash_attn():
|
| 352 |
for module in decoder_attention_modules:
|
| 353 |
module._use_flash_kvcache = True
|
| 354 |
flash_kvcache_attention_modules.append(module)
|
|
|
|
| 1257 |
)
|
| 1258 |
|
| 1259 |
def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool:
    """Report whether flash-attention is usable for *device* and *dtype*.

    Requires the flash-attn package to be importable, a CUDA device, and a
    half-precision dtype (fp16 or bf16).
    """
    if not _has_flash_attn():
        return False
    if device.type != "cuda":
        return False
    return dtype in (torch.float16, torch.bfloat16)
|
| 1261 |
|
| 1262 |
def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype:
|
| 1263 |
if x.device.type != "cuda":
|
|
|
|
| 1285 |
"(HAS_FLASH_ATTN=%s).",
|
| 1286 |
x.device,
|
| 1287 |
backend_dtype,
|
| 1288 |
+
_has_flash_attn(),
|
| 1289 |
)
|
| 1290 |
return "sdpa"
|
| 1291 |
|
|
|
|
| 1460 |
max_seqlen_q: int,
|
| 1461 |
max_seqlen_k: int,
|
| 1462 |
) -> torch.Tensor:
|
| 1463 |
+
flash_attn_varlen_func = _get_flash_attn_varlen_func()
|
| 1464 |
if flash_attn_varlen_func is None:
|
| 1465 |
raise RuntimeError("flash-attn is not installed.")
|
| 1466 |
window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
|
|
|
|
| 1546 |
return out
|
| 1547 |
|
| 1548 |
def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
|
| 1549 |
+
flash_attn_with_kvcache = _get_flash_attn_with_kvcache()
|
|
|
|
| 1550 |
if self.context is None:
|
| 1551 |
raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
|
| 1552 |
+
if flash_attn_with_kvcache is None:
|
| 1553 |
+
raise RuntimeError("flash-attn is not installed.")
|
| 1554 |
|
| 1555 |
batch_size, chunk_length, _ = x.shape
|
| 1556 |
q, k_cur, v_cur = self._project_qkv(x)
|