Kamal-prog-code committed on
Commit
cca03fe
·
1 Parent(s): 284b475

Add fallback for LlamaFlashAttention2 in ensure_llama_flash_attn2 function

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -13,6 +13,18 @@ import numpy as np
13
  import base64
14
  from io import StringIO, BytesIO
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  MODEL_NAME = "deepseek-ai/DeepSeek-OCR-2"
17
  MODEL_REVISION = "e6322a289fe5b5218278d276d4e7c58e8103f46a"
18
  DOTS_OCR_MODEL = "rednote-hilab/dots.ocr"
@@ -405,4 +417,4 @@ with gr.Blocks(title="DeepSeek-OCR-2") as demo:
405
  [text_out, md_out, raw_out, img_out, gallery])
406
 
407
  if __name__ == "__main__":
408
- demo.queue(max_size=20).launch(theme=gr.themes.Soft())
 
13
  import base64
14
  from io import StringIO, BytesIO
15
 
16
+ def ensure_llama_flash_attn2():
17
+ try:
18
+ from transformers.models.llama import modeling_llama as llama_mod
19
+ except Exception:
20
+ return
21
+ if not hasattr(llama_mod, "LlamaFlashAttention2"):
22
+ class LlamaFlashAttention2: # fallback shim; not used when attn impl is SDPA
23
+ pass
24
+ llama_mod.LlamaFlashAttention2 = LlamaFlashAttention2
25
+
26
+ ensure_llama_flash_attn2()
27
+
28
  MODEL_NAME = "deepseek-ai/DeepSeek-OCR-2"
29
  MODEL_REVISION = "e6322a289fe5b5218278d276d4e7c58e8103f46a"
30
  DOTS_OCR_MODEL = "rednote-hilab/dots.ocr"
 
417
  [text_out, md_out, raw_out, img_out, gallery])
418
 
419
  if __name__ == "__main__":
420
+ demo.queue(max_size=20).launch(theme=gr.themes.Soft())