DocUA committed on
Commit
e0b7657
·
1 Parent(s): 6379065

feat: Add LlamaFlashAttention2 compatibility alias and eager attention implementation for model loading.

Browse files
Files changed (2) hide show
  1. app.py +7 -0
  2. app_hf.py +8 -0
app.py CHANGED
@@ -10,6 +10,13 @@ import io
10
  import gc
11
  import warnings
12
 
 
 
 
 
 
 
 
13
  # Suppress annoying warnings
14
  warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
15
  warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
 
10
  import gc
11
  import warnings
12
 
13
+ try:
14
+ from transformers.models.llama import modeling_llama as _modeling_llama
15
+ if not hasattr(_modeling_llama, "LlamaFlashAttention2") and hasattr(_modeling_llama, "LlamaAttention"):
16
+ _modeling_llama.LlamaFlashAttention2 = _modeling_llama.LlamaAttention
17
+ except Exception:
18
+ pass
19
+
20
  # Suppress annoying warnings
21
  warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
22
  warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
app_hf.py CHANGED
@@ -20,6 +20,13 @@ import fitz # PyMuPDF
20
  import io
21
  import gc
22
 
 
 
 
 
 
 
 
23
  # Suppress annoying warnings
24
  warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
25
  warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
@@ -50,6 +57,7 @@ class ModelManager:
50
  model_name,
51
  trust_remote_code=True,
52
  use_safetensors=True,
 
53
  torch_dtype=dtype
54
  )
55
  model.eval()
 
20
  import io
21
  import gc
22
 
23
+ try:
24
+ from transformers.models.llama import modeling_llama as _modeling_llama
25
+ if not hasattr(_modeling_llama, "LlamaFlashAttention2") and hasattr(_modeling_llama, "LlamaAttention"):
26
+ _modeling_llama.LlamaFlashAttention2 = _modeling_llama.LlamaAttention
27
+ except Exception:
28
+ pass
29
+
30
  # Suppress annoying warnings
31
  warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
32
  warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
 
57
  model_name,
58
  trust_remote_code=True,
59
  use_safetensors=True,
60
+ attn_implementation="eager",
61
  torch_dtype=dtype
62
  )
63
  model.eval()