Kuangwei Chen committed on
Commit
8ee35eb
·
1 Parent(s): 737e5b1

optimize flash attention import

Browse files
Files changed (1) hide show
  1. modeling_moss_audio_tokenizer.py +33 -11
modeling_moss_audio_tokenizer.py CHANGED
@@ -16,11 +16,13 @@
16
  from __future__ import annotations
17
 
18
  import copy
 
19
  import math
20
  import sys
21
  import types
22
  from contextlib import ExitStack, contextmanager
23
  from dataclasses import dataclass
 
24
  from pathlib import Path
25
  from typing import cast
26
 
@@ -91,13 +93,31 @@ except ImportError:
91
 
92
  logger = logging.get_logger(__name__)
93
 
94
- try:
95
- from flash_attn import flash_attn_varlen_func
96
 
97
- HAS_FLASH_ATTN = True
98
- except ImportError:
99
- flash_attn_varlen_func = None
100
- HAS_FLASH_ATTN = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
 
103
  SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
@@ -328,7 +348,7 @@ class MossAudioTokenizerDecodeSession:
328
  decoder_attention_modules.append(module)
329
 
330
  flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
331
- if use_cuda_graph and HAS_FLASH_ATTN:
332
  for module in decoder_attention_modules:
333
  module._use_flash_kvcache = True
334
  flash_kvcache_attention_modules.append(module)
@@ -1237,7 +1257,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1237
  )
1238
 
1239
  def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool:
1240
- return HAS_FLASH_ATTN and device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
1241
 
1242
  def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype:
1243
  if x.device.type != "cuda":
@@ -1265,7 +1285,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1265
  "(HAS_FLASH_ATTN=%s).",
1266
  x.device,
1267
  backend_dtype,
1268
- HAS_FLASH_ATTN,
1269
  )
1270
  return "sdpa"
1271
 
@@ -1440,6 +1460,7 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1440
  max_seqlen_q: int,
1441
  max_seqlen_k: int,
1442
  ) -> torch.Tensor:
 
1443
  if flash_attn_varlen_func is None:
1444
  raise RuntimeError("flash-attn is not installed.")
1445
  window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
@@ -1525,10 +1546,11 @@ class MossAudioTokenizerMultiheadAttention(StreamingModule):
1525
  return out
1526
 
1527
  def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
1528
- from flash_attn import flash_attn_with_kvcache
1529
-
1530
  if self.context is None:
1531
  raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
 
 
1532
 
1533
  batch_size, chunk_length, _ = x.shape
1534
  q, k_cur, v_cur = self._project_qkv(x)
 
16
  from __future__ import annotations
17
 
18
  import copy
19
+ import importlib
20
  import math
21
  import sys
22
  import types
23
  from contextlib import ExitStack, contextmanager
24
  from dataclasses import dataclass
25
+ from functools import lru_cache
26
  from pathlib import Path
27
  from typing import cast
28
 
 
93
 
94
  logger = logging.get_logger(__name__)
95
 
 
 
96
 
97
+ @lru_cache(maxsize=1)
98
+ def _get_flash_attn_module():
99
+ try:
100
+ return importlib.import_module("flash_attn")
101
+ except Exception:
102
+ return None
103
+
104
+
105
+ def _has_flash_attn() -> bool:
106
+ return _get_flash_attn_module() is not None
107
+
108
+
109
+ def _get_flash_attn_varlen_func():
110
+ flash_attn_module = _get_flash_attn_module()
111
+ if flash_attn_module is None:
112
+ return None
113
+ return getattr(flash_attn_module, "flash_attn_varlen_func", None)
114
+
115
+
116
+ def _get_flash_attn_with_kvcache():
117
+ flash_attn_module = _get_flash_attn_module()
118
+ if flash_attn_module is None:
119
+ return None
120
+ return getattr(flash_attn_module, "flash_attn_with_kvcache", None)
121
 
122
 
123
  SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"sdpa", "flash_attention_2"}
 
348
  decoder_attention_modules.append(module)
349
 
350
  flash_kvcache_attention_modules: list[MossAudioTokenizerMultiheadAttention] = []
351
+ if use_cuda_graph and _has_flash_attn():
352
  for module in decoder_attention_modules:
353
  module._use_flash_kvcache = True
354
  flash_kvcache_attention_modules.append(module)
 
1257
  )
1258
 
1259
  def _supports_flash_attention(self, device: torch.device, dtype: torch.dtype) -> bool:
1260
+ return _has_flash_attn() and device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
1261
 
1262
  def _get_backend_check_dtype(self, x: torch.Tensor) -> torch.dtype:
1263
  if x.device.type != "cuda":
 
1285
  "(HAS_FLASH_ATTN=%s).",
1286
  x.device,
1287
  backend_dtype,
1288
+ _has_flash_attn(),
1289
  )
1290
  return "sdpa"
1291
 
 
1460
  max_seqlen_q: int,
1461
  max_seqlen_k: int,
1462
  ) -> torch.Tensor:
1463
+ flash_attn_varlen_func = _get_flash_attn_varlen_func()
1464
  if flash_attn_varlen_func is None:
1465
  raise RuntimeError("flash-attn is not installed.")
1466
  window_size = (self.context, 0) if (self.context is not None and self.causal) else (-1, -1)
 
1546
  return out
1547
 
1548
  def _forward_streaming_flash_kvcache(self, x: torch.Tensor, state: MHAState) -> torch.Tensor:
1549
+ flash_attn_with_kvcache = _get_flash_attn_with_kvcache()
 
1550
  if self.context is None:
1551
  raise RuntimeError("flash_attn_with_kvcache requires a finite streaming context.")
1552
+ if flash_attn_with_kvcache is None:
1553
+ raise RuntimeError("flash-attn is not installed.")
1554
 
1555
  batch_size, chunk_length, _ = x.shape
1556
  q, k_cur, v_cur = self._project_qkv(x)