upprize committed on
Commit
700ddbf
·
1 Parent(s): 8a06d9f
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +13 -3
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  models:
12
  - rednote-hilab/DotsOCR
13
  tags:
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Multilingual document layout parsing with OCR
12
  models:
13
  - rednote-hilab/DotsOCR
14
  tags:
app.py CHANGED
@@ -27,13 +27,23 @@ def load_model():
27
  global model, processor
28
  if model is None:
29
  print(f"Loading model weights from {MODEL_PATH}...")
 
 
 
 
 
 
 
 
 
 
30
  model = AutoModelForCausalLM.from_pretrained(
31
  MODEL_PATH,
32
- attn_implementation="flash_attention_2",
33
- torch_dtype=torch.bfloat16,
34
  device_map="auto",
35
  trust_remote_code=True,
36
- token=HF_TOKEN
 
37
  )
38
  print("Model loaded successfully.")
39
 
 
27
  global model, processor
28
  if model is None:
29
  print(f"Loading model weights from {MODEL_PATH}...")
30
+
31
+ # Try to use FlashAttention2 if available, otherwise use default attention
32
+ try:
33
+ import flash_attn
34
+ attn_implementation = "flash_attention_2"
35
+ print("Using FlashAttention2 for faster inference")
36
+ except ImportError:
37
+ attn_implementation = "eager"
38
+ print("FlashAttention2 not available, using default attention")
39
+
40
  model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_PATH,
42
+ dtype=torch.bfloat16,
 
43
  device_map="auto",
44
  trust_remote_code=True,
45
+ token=HF_TOKEN,
46
+ attn_implementation=attn_implementation
47
  )
48
  print("Model loaded successfully.")
49